From 74bde48eaacf6c42a9cb4ebbdb9beb9e76f16a2e Mon Sep 17 00:00:00 2001 From: Kavit Shah Date: Fri, 28 Jul 2023 12:41:37 -0700 Subject: [PATCH] Initial commit --- .circleci/config.yml | 29 + .flake8 | 12 + .gitignore | 9 + .gitmodules | 12 + .pre-commit-config.yaml | 46 + CODE_OF_CONDUCT.md | 80 ++ CONTRIBUTING.md | 31 + LICENSE | 21 + README.md | 207 +++++ bd_spot_wrapper/.gitignore | 3 + bd_spot_wrapper/README.md | 40 + bd_spot_wrapper/__init__.py | 3 + bd_spot_wrapper/data/depth_transforms.txt | 11 + bd_spot_wrapper/data/image_sources.txt | 329 +++++++ bd_spot_wrapper/data/transforms_example.txt | 89 ++ bd_spot_wrapper/generate_executables.py | 41 + bd_spot_wrapper/requirements.txt | 3 + bd_spot_wrapper/setup.py | 16 + bd_spot_wrapper/spot_wrapper/__init__.py | 3 + bd_spot_wrapper/spot_wrapper/draw_square.py | 43 + bd_spot_wrapper/spot_wrapper/estop.py | 233 +++++ .../spot_wrapper/headless_estop.py | 112 +++ bd_spot_wrapper/spot_wrapper/home_robot.py | 10 + .../spot_wrapper/keyboard_teleop.py | 173 ++++ .../spot_wrapper/monitor_nav_pose.py | 21 + bd_spot_wrapper/spot_wrapper/roll_over.py | 26 + bd_spot_wrapper/spot_wrapper/selfright.py | 26 + bd_spot_wrapper/spot_wrapper/sit.py | 26 + bd_spot_wrapper/spot_wrapper/spot.py | 787 ++++++++++++++++ bd_spot_wrapper/spot_wrapper/stand.py | 26 + bd_spot_wrapper/spot_wrapper/utils.py | 102 +++ .../spot_wrapper/view_arm_proprioception.py | 25 + bd_spot_wrapper/spot_wrapper/view_camera.py | 77 ++ .../spot_wrapper/view_camera_and_record.py | 102 +++ installation/ISSUES.md | 220 +++++ installation/SETUP_INSTRUCTIONS.md | 237 +++++ installation/environment.yml | 594 +++++++++++++ spot_rl_experiments/.gitignore | 7 + spot_rl_experiments/README.md | 13 + spot_rl_experiments/__init__.py | 3 + spot_rl_experiments/configs/config.yaml | 70 ++ .../configs/ros_topic_names.yaml | 22 + .../comparisons/gaze_all_objects.py | 46 + .../experiments/comparisons/multiple_gaze.py | 48 + .../experiments/comparisons/nav_compare.py | 154 ++++ spot_rl_experiments/generate_executables.py | 44 + spot_rl_experiments/setup.py | 16 + .../spot_rl/baselines/go_to_waypoint.py | 70 ++ spot_rl_experiments/spot_rl/envs/base_env.py | 838 ++++++++++++++++++ spot_rl_experiments/spot_rl/envs/gaze_env.py | 155 ++++ spot_rl_experiments/spot_rl/envs/lang_env.py | 421 +++++++++ .../spot_rl/envs/mobile_manipulation_env.py | 419 +++++++++ spot_rl_experiments/spot_rl/envs/nav_env.py | 93 ++ spot_rl_experiments/spot_rl/envs/place_env.py | 91 ++ spot_rl_experiments/spot_rl/launch/core.sh | 19 + .../spot_rl/launch/kill_sessions.sh | 9 + .../spot_rl/launch/local_listener.sh | 7 + .../spot_rl/launch/local_only.sh | 17 + spot_rl_experiments/spot_rl/llm/.gitignore | 130 +++ spot_rl_experiments/spot_rl/llm/README.md | 36 + .../spot_rl/llm/src/conf/config.yaml | 5 + .../spot_rl/llm/src/conf/llm/openai.yaml | 41 + .../conf/prompt/rearrange_easy_few_shot.yaml | 16 + .../conf/prompt/rearrange_easy_zero_shot.yaml | 10 + .../spot_rl/llm/src/notebook.ipynb | 54 ++ .../spot_rl/llm/src/rearrange_llm.py | 83 ++ .../spot_rl/models/__init__.py | 7 + spot_rl_experiments/spot_rl/models/owlvit.py | 246 +++++ .../spot_rl/models/sentence_similarity.py | 78 ++ spot_rl_experiments/spot_rl/real_policy.py | 310 +++++++ spot_rl_experiments/spot_rl/ros_img_vis.py | 279 ++++++ spot_rl_experiments/spot_rl/spot_ros_node.py | 331 +++++++ spot_rl_experiments/spot_rl/utils/autodock.py | 11 + .../spot_rl/utils/depth_map_utils.py | 324 +++++++ .../spot_rl/utils/generate_place_goal.py | 38 + .../spot_rl/utils/helper_nodes.py | 89 ++ .../spot_rl/utils/img_publishers.py | 486 ++++++++++ .../spot_rl/utils/mask_rcnn_utils.py | 80 ++ .../spot_rl/utils/remote_spot.py | 144 +++ .../spot_rl/utils/remote_spot_listener.py | 95 ++ .../spot_rl/utils/robot_subscriber.py | 98 ++ .../utils/run_local_parallel_inference.py | 23 + .../spot_rl/utils/run_parallel_inference.py | 25 + .../spot_rl/utils/spot_rl_launch_local.py | 3 + .../spot_rl/utils/stopwatch.py | 35 + spot_rl_experiments/spot_rl/utils/utils.py | 122 +++ .../spot_rl/utils/waypoint_recorder.py | 391 ++++++++ .../spot_rl/utils/whisper_translator.py | 157 ++++ third_party/DeblurGANv2 | 1 + third_party/habitat-lab | 1 + third_party/mask_rcnn_detectron2 | 1 + 91 files changed, 10137 insertions(+) create mode 100644 .circleci/config.yml create mode 100644 .flake8 create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 .pre-commit-config.yaml create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 bd_spot_wrapper/.gitignore create mode 100644 bd_spot_wrapper/README.md create mode 100644 bd_spot_wrapper/__init__.py create mode 100644 bd_spot_wrapper/data/depth_transforms.txt create mode 100644 bd_spot_wrapper/data/image_sources.txt create mode 100644 bd_spot_wrapper/data/transforms_example.txt create mode 100644 bd_spot_wrapper/generate_executables.py create mode 100644 bd_spot_wrapper/requirements.txt create mode 100644 bd_spot_wrapper/setup.py create mode 100644 bd_spot_wrapper/spot_wrapper/__init__.py create mode 100644 bd_spot_wrapper/spot_wrapper/draw_square.py create mode 100644 bd_spot_wrapper/spot_wrapper/estop.py create mode 100644 bd_spot_wrapper/spot_wrapper/headless_estop.py create mode 100644 bd_spot_wrapper/spot_wrapper/home_robot.py create mode 100644 bd_spot_wrapper/spot_wrapper/keyboard_teleop.py create mode 100644 bd_spot_wrapper/spot_wrapper/monitor_nav_pose.py create mode 100644 bd_spot_wrapper/spot_wrapper/roll_over.py create mode 100644 bd_spot_wrapper/spot_wrapper/selfright.py create mode 100644 bd_spot_wrapper/spot_wrapper/sit.py create mode 100644 bd_spot_wrapper/spot_wrapper/spot.py create mode 100644 bd_spot_wrapper/spot_wrapper/stand.py create mode 100644 bd_spot_wrapper/spot_wrapper/utils.py create mode 100644 bd_spot_wrapper/spot_wrapper/view_arm_proprioception.py create mode 100644 bd_spot_wrapper/spot_wrapper/view_camera.py create mode 100644 bd_spot_wrapper/spot_wrapper/view_camera_and_record.py create mode 100644 installation/ISSUES.md create mode 100644 installation/SETUP_INSTRUCTIONS.md create mode 100644 installation/environment.yml create mode 100644 spot_rl_experiments/.gitignore create mode 100644 spot_rl_experiments/README.md create mode 100644 spot_rl_experiments/__init__.py create mode 100644 spot_rl_experiments/configs/config.yaml create mode 100644 spot_rl_experiments/configs/ros_topic_names.yaml create mode 100644 spot_rl_experiments/experiments/comparisons/gaze_all_objects.py create mode 100644 spot_rl_experiments/experiments/comparisons/multiple_gaze.py create mode 100644 spot_rl_experiments/experiments/comparisons/nav_compare.py create mode 100644 spot_rl_experiments/generate_executables.py create mode 100644 spot_rl_experiments/setup.py create mode 100644 spot_rl_experiments/spot_rl/baselines/go_to_waypoint.py create mode 100644 spot_rl_experiments/spot_rl/envs/base_env.py create mode 100644 spot_rl_experiments/spot_rl/envs/gaze_env.py create mode 100644 spot_rl_experiments/spot_rl/envs/lang_env.py create mode 100644 spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py create mode 100644 spot_rl_experiments/spot_rl/envs/nav_env.py create mode 100644 spot_rl_experiments/spot_rl/envs/place_env.py create mode 100644 spot_rl_experiments/spot_rl/launch/core.sh create mode 100644 spot_rl_experiments/spot_rl/launch/kill_sessions.sh create mode 100644 spot_rl_experiments/spot_rl/launch/local_listener.sh create mode 100644 spot_rl_experiments/spot_rl/launch/local_only.sh create mode 100644 spot_rl_experiments/spot_rl/llm/.gitignore create mode 100644 spot_rl_experiments/spot_rl/llm/README.md create mode 100644 spot_rl_experiments/spot_rl/llm/src/conf/config.yaml create mode 100644 spot_rl_experiments/spot_rl/llm/src/conf/llm/openai.yaml create mode 100644 spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_few_shot.yaml create mode 100644 spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_zero_shot.yaml create mode 100644 spot_rl_experiments/spot_rl/llm/src/notebook.ipynb create mode 100644 spot_rl_experiments/spot_rl/llm/src/rearrange_llm.py create mode 100644 spot_rl_experiments/spot_rl/models/__init__.py create mode 100644 spot_rl_experiments/spot_rl/models/owlvit.py create mode 100644 spot_rl_experiments/spot_rl/models/sentence_similarity.py create mode 100644 spot_rl_experiments/spot_rl/real_policy.py create mode 100644 spot_rl_experiments/spot_rl/ros_img_vis.py create mode 100644 spot_rl_experiments/spot_rl/spot_ros_node.py create mode 100644 spot_rl_experiments/spot_rl/utils/autodock.py create mode 100644 spot_rl_experiments/spot_rl/utils/depth_map_utils.py create mode 100644 spot_rl_experiments/spot_rl/utils/generate_place_goal.py create mode 100644 spot_rl_experiments/spot_rl/utils/helper_nodes.py create mode 100644 spot_rl_experiments/spot_rl/utils/img_publishers.py create mode 100644 spot_rl_experiments/spot_rl/utils/mask_rcnn_utils.py create mode 100644 spot_rl_experiments/spot_rl/utils/remote_spot.py create mode 100644 spot_rl_experiments/spot_rl/utils/remote_spot_listener.py create mode 100644 spot_rl_experiments/spot_rl/utils/robot_subscriber.py create mode 100644 spot_rl_experiments/spot_rl/utils/run_local_parallel_inference.py create mode 100644 spot_rl_experiments/spot_rl/utils/run_parallel_inference.py create mode 100644 spot_rl_experiments/spot_rl/utils/spot_rl_launch_local.py create mode 100644 spot_rl_experiments/spot_rl/utils/stopwatch.py create mode 100644 spot_rl_experiments/spot_rl/utils/utils.py create mode 100644 spot_rl_experiments/spot_rl/utils/waypoint_recorder.py create mode 100644 spot_rl_experiments/spot_rl/utils/whisper_translator.py create mode 160000 third_party/DeblurGANv2 create mode 160000 third_party/habitat-lab create mode 160000 third_party/mask_rcnn_detectron2 diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..639760af --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,29 @@ +version: 2.1 + +gpu: &gpu + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + FPS_THRESHOLD: 900 + +jobs: + pre-commit: + working_directory: ~/spot-sim2real + resource_class: small + docker: + - image: cimg/python:3.8 + steps: + - checkout + - run: + name: Running precommit checks + command: | + mkdir .mypy_cache + pip install pre-commit==3.1.1 + pre-commit install-hooks + pre-commit run --all-files + +workflows: + main: + jobs: + - pre-commit diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..4ea1f4a4 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +ignore = E203, E266, E501, W503, F403, F401 +max-line-length = 89 +max-complexity = 25 +select = B,C,E,F,W,T4,B9 + +# Ignored rules +# E203: whitespace +# E266: too many ### +# E501: line too long +# F403: imported but not used +# F401: line breaks before binary operators diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a769c23d --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +**/__pycache__/** +.vscode + +bd_spot_wrapper/spot_wrapper.egg-info/** +bd_spot_wrapper/spot_wrapper/home.txt +spot_rl_experiments/spot_rl.egg-info/** +spot_rl_experiments/configs/waypoints.yaml + +**/grasp_visualizations/** diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..cec95f5d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "third_party/habitat-lab"] + path = third_party/habitat-lab + url = git@github.com:naokiyokoyama/habitat-lab.git + branch = pvp +[submodule "third_party/DeblurGANv2"] + path = third_party/DeblurGANv2 + url = git@github.com:naokiyokoyama/DeblurGANv2.git + branch = master +[submodule "third_party/mask_rcnn_detectron2"] + path = third_party/mask_rcnn_detectron2 + url = https://github.com/naokiyokoyama/mask_rcnn_detectron2 + branch = main diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..08d06377 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,46 @@ +repos: + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + exclude: | + (?x)^( + mask_rcnn_detectron2/ + | habitat-lab/ + | DeblurGANv2/ + ) + + - repo: https://github.com/pycqa/isort + rev: 5.11.5 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + exclude: | + (?x)^( + mask_rcnn_detectron2/ + | habitat-lab/ + | DeblurGANv2/ + ) + + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + exclude: | + (?x)^( + mask_rcnn_detectron2/ + | habitat-lab/ + | DeblurGANv2/ + ) + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.981 + hooks: + - id: mypy + args: [--install-types, --non-interactive, --no-strict-optional, --ignore-missing-imports] + exclude: | + (?x)^( + mask_rcnn_detectron2/ + | habitat-lab/ + | DeblurGANv2/ + ) \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..08b500a2 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..e16c3859 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to spot-sim2real +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to spot-sim2real, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..1b277b98 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Meta Platforms, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..f93ab5dd --- /dev/null +++ b/README.md @@ -0,0 +1,207 @@ +# :robot: Spot-Sim2Real +Spot-Sim2Real is a modular library for development of Spot for embodied AI tasks (e.g., [Language-guided Skill Coordination (LSC)](https://languageguidedskillcoordination.github.io/), [Adaptive Skill Coordination (ASC)](https://arxiv.org/pdf/2304.00410.pdf)) -- configuring Spot robots, controlling sensorimotor skills, and coordinating Large Language Models (LLMs). + +## :memo: Setup instructions +Please refer to the [setup instructions page](/installation/SETUP_INSTRUCTIONS.md) for information on how to setup the repo. + +## :computer: Connecting to the robot +Computer can be connected to the robot in one of the following modes. +1. Ethernet (Gives best network speed, but it is cluttery :sad: )\ +This mode can be used to create a wired connection with the robot. Useful for teleoperating the robot via computer +2. Access Point Mode\ +This is a wireless mode where robot creates its wifi network. Connect robot to this mode for teleoperating it using controller over long distances. Robot is in Access Point mode if you see a wifi with name like `spot-BD-***********` (where * is a number) +3. Client Mode (Gives 2nd best network speed, we usually prefer this)\ +This is a wireless mode where robot is connected to an external wifi network (from a nearby router). Computer should be connected to this same network, wired connection between router and computer will be faster than wireless connection. + +**Follow the steps from [Spot's Network Setup](https://support.bostondynamics.com/s/article/Spot-network-setup) page by Boston Dynamics to connect to the robot.** + +After setting up spot in correct network configuration, please add its IP inside bashrc +```bash +echo "export SPOT_IP=" >> ~/.bashrc +source ~/.bashrc +``` + +Test and ensure you can ping spot +```bash +ping $SPOT_IP +``` + +If you get response like this, then you are on right network +```bash +(spot_ros) user@linux-machine:~$ ping $SPOT_IP +PING 192.168.1.5 (192.168.1.5) 56(84) bytes of data. +64 bytes from 192.168.1.5: icmp_seq=1 ttl=64 time=8.87 ms +64 bytes from 192.168.1.5: icmp_seq=2 ttl=64 time=7.36 ms +``` + +Before starting to run the code, you need to ensure that all ROS env variables are setup properly inside bashrc. Please follow the steps from [Setting ROS env variables](/installation/SETUP_INSTRUCTIONS.md#setting-ros-env-variables) for proper ROS env var setup. + +## :desktop_computer: Getting to the repo +Go to the repository +```bash +cd /path/to/spot-sim2real/ +``` + +The code for the demo lies inside the `main` branch. +```bash +# Check your current git branch +git rev-parse --abbrev-ref HEAD + +# If you are not in the `main` branch, then checkout to the `main` branch +git checkout main +``` + +## :light_rail: Try teleoperating the robot using keyboard +### :rotating_light: Running Emergency Stop +* Since we do not have a physical emergency stop button (like the large red push buttons), we need to run an e-stop node. + ```bash + python -m spot_wrapper.estop + ``` + +- Keep this window open at all the times, if the robot starts misbehaving you should be able to quickly press `s` or `space_bar` to kill the robot + +### :musical_keyboard: Running keyboard teleop +* Ensure you have the Estop up and running in one terminal. [Follow these instructions for e-stop](/README.md#rotating_light-running-emergency-stop) +* Run keyboard teleop with this command in a new terminal + ```bash + spot_keyboard_teleop + ``` + +## :video_game: Instructions to record waypoints (use joystick to move robot around) +- Before running scripts on the robot, waypoints should be recorded. These waypoints exist inside file `spot-sim2real/spot_rl_experiments/configs/waypoints.yaml` + +- Before recording receptacles, make the robot sit at home position then run following command + ```bash + spot_reset_home + ``` + +- There are 2 types of waypoints that one can record, + 1. clutter - These only require the (x, y, theta) of the target receptacle + 2. place - These requre (x, y, theta) for target receptable as well as (x, y, z) for exact drop location on the receptacle + +- To record a clutter target, teleoperate the robot to reach near the receptacle target (using joystick). Once robot is at a close distance to receptacle, run the following command + ```bash + spot_rl_waypoint_recorder -c + ``` + +- To record a place target, teleoperate the robot to reach near the receptacle target (using joystick). Once robot is at a close distance to receptacle, use manipulation mode in the joystick to manipulate the end-effector at desired (x,y,z) position. Once you are satisfied with the end-effector position, run the following command + ```bash + spot_rl_waypoint_recorder -p + ``` + + +## :rocket: Running instructions +### Running the demo (ASC/LSC/Seq-Experts) +#### Step1. Run the local launch executable +- In a new terminal, run the executable as + ```bash + spot_rl_launch_local + ``` + This command starts 4 tmux sessions\n + + 1. roscore + 2. img_publishers + 3. proprioception + 4. tts + +- You can run `tmux ls` in the terminal to ensure that all 4 tmux sessions are running. + You need to ensure that all 4 sessions remain active until 70 seconds after running the `spot_rl_launch_local`. If anyone of them dies before 70 seconds, it means there is some issue and you should rerun `spot_rl_launch_local`. + +- You should try re-running `spot_rl_launch_local` atleast 2-3 times to see if the issue still persists. Many times roscore takes a while to start due to which other nodes die, re-running can fix this issue. + +- You can verify if all ros nodes are up and running as expected if the output of `rostopic list` looks like the following + ```bash + (spot_ros) user@linux-machine:~$ rostopic list + /filtered_hand_depth + /filtered_head_depth + /hand_rgb + /mask_rcnn_detections + /mask_rcnn_visualizations + /raw_hand_depth + /raw_head_depth + /robot_state + /rosout + /rosout_agg + /text_to_speech + ``` +- If you don't get the output as follows, one of the tmux sessions might be failing. Follow [the debugging strategies](/installation/ISSUES.md#debugging-strategies-for-spot_rl_launch_local-if-any-one-of-the-4-sessions-are-dying-before-70-seconds) described in ISSUES.md for triaging and resolving these errors. + +#### Step2. Run ROS image visualization +- This is the image visualization tool that helps to understand what robot is seeing and perceiving from the world + ```bash + spot_rl_ros_img_vis + ``` +- Running this command will open an image viewer and start printing image frequency from different rosotopics. + +- If the image frequency corresponding to `mask_rcnn_visualizations` is too large and constant (like below), it means that the bounding box detector has not been fully initialized yet + ```bash + raw_head_depth: 9.33 raw_hand_depth: 9.33 hand_rgb: 9.33 filtered_head_depth: 11.20 filtered_hand_depth: 11.20 mask_rcnn_visualizations: 11.20 + raw_head_depth: 9.33 raw_hand_depth: 9.33 hand_rgb: 9.33 filtered_head_depth: 11.20 filtered_hand_depth: 8.57 mask_rcnn_visualizations: 11.20 + raw_head_depth: 9.33 raw_hand_depth: 9.33 hand_rgb: 9.33 filtered_head_depth: 8.34 filtered_hand_depth: 8.57 mask_rcnn_visualizations: 11.20 + ``` + + Once the `mask_rcnn_visualizations` start becoming dynamic (like below), you can proceed with next steps + ```bash + raw_head_depth: 6.87 raw_hand_depth: 6.88 hand_rgb: 6.86 filtered_head_depth: 4.77 filtered_hand_depth: 5.01 mask_rcnn_visualizations: 6.14 + raw_head_depth: 6.87 raw_hand_depth: 6.88 hand_rgb: 6.86 filtered_head_depth: 4.77 filtered_hand_depth: 5.01 mask_rcnn_visualizations: 5.33 + raw_head_depth: 4.14 raw_hand_depth: 4.15 hand_rgb: 4.13 filtered_head_depth: 4.15 filtered_hand_depth: 4.12 mask_rcnn_visualizations: 4.03 + raw_head_depth: 4.11 raw_hand_depth: 4.12 hand_rgb: 4.10 filtered_head_depth: 4.15 filtered_hand_depth: 4.12 mask_rcnn_visualizations: 4.03 + ``` + +#### Step3. Reset home **in a new terminal** +- This is an important step. Ensure robot is at its start location and sitting, then run the following command in a new terminal + ```bash + spot_reset_home + ``` + +- The waypoints that were recorded are w.r.t the home location. Since the odometry drifts while robot is moving, **it is necessary to reset home before start of every new run** + +#### Step4. Emergency stop +- Follow the steps described in [e-stop section](/README.md#rotating_light-running-emergency-stop) + + +#### Step5. Main demo code **in a new terminal** +- Ensure you have correctly added the waypoints of interest by following the [intructions to record waypoints](/README.md#rocket-running-instructions) +- In a new terminal you can now run the code of your choice + 1. To run Sequencial experts + ```bash + spot_rl_mobile_manipulation_env + ``` + + 2. To run Adaptive skill coordination + ```bash + spot_rl_mobile_manipulation_env -m + ``` + + 3. To run Language instructions with Sequencial experts, *ensure the usb microphone is connected to the computer* + ```bash + python spot_rl_experiments/spot_rl/envs/lang_env.py + ``` + + +- If you are done with demo of one of the above code and want to run another code, you do not need to re-run other sessions and nodes. Running a new command in the same terminal will work just fine. But **make sure to bring robot at home location and reset its home** using `spot_reset_home` in the same terminal + +## :mega: Acknowledgement +We thank [Naoki Yokoyama](http://naoki.io/) for setting up the foundation of the codebase, and [Joanne Truong](https://www.joannetruong.com/) for polishing the codebase. Spot-Sim2Real is built upon Naoki's codebases: [bd_spot_wrapper](https://github.com/naokiyokoyama/bd_spot_wrapper) and [spot_rl_experiments +](https://github.com/naokiyokoyama/spot_rl_experiments), and with new features (LLMs, pytest) and improving robustness. + + +## :writing_hand: Citations +If you find this repository helpful, feel free to cite our papers: [Adaptive Skill Coordination (ASC)](https://arxiv.org/pdf/2304.00410.pdf) and [Language-guided Skill Coordination (LSC)](https://languageguidedskillcoordination.github.io/). +``` +@article{yokoyama2023adaptive, + title={Adaptive Skill Coordination for Robotic Mobile Manipulation}, + author={Yokoyama, Naoki and Clegg, Alexander William and Truong, Joanne and Undersander, Eric and Yang, Tsung-Yen and Arnaud, Sergio and Ha, Sehoon and Batra, Dhruv and Rai, Akshara}, + journal={arXiv preprint arXiv:2304.00410}, + year={2023} +} + +@misc{yang2023adaptive, + title={LSC: Language-guided Skill Coordination for Open-Vocabulary Mobile Pick-and-Place}, + author={Yang, Tsung-Yen and Arnaud, Sergio and Shah, Kavit and Yokoyama, Naoki and Clegg, Alexander William and Truong, Joanne and Undersander, Eric and Maksymets, Oleksandr and Ha, Sehoon and Kalakrishnan, Mrinal and Mottaghi, Roozbeh and Batra, Dhruv and Rai, Akshara}, + howpublished={\url{https://languageguidedskillcoordination.github.io/}} +} +``` + +## License +Spot-Sim2Real is MIT licensed. See the [LICENSE file](/LICENSE) for details. \ No newline at end of file diff --git a/bd_spot_wrapper/.gitignore b/bd_spot_wrapper/.gitignore new file mode 100644 index 00000000..aad79e3b --- /dev/null +++ b/bd_spot_wrapper/.gitignore @@ -0,0 +1,3 @@ +spot_wrapper.egg-info/ +__pycache__ +spot_wrapper/home.txt diff --git a/bd_spot_wrapper/README.md b/bd_spot_wrapper/README.md new file mode 100644 index 00000000..6fa951c2 --- /dev/null +++ b/bd_spot_wrapper/README.md @@ -0,0 +1,40 @@ +# Simple Python API for Spot + +## Installation + +Create the conda env: + +```bash +conda create -n spot_env -y python=3.6 +conda activate spot_env +``` +Install requirements +```bash +pip install -r requirements.txt +``` +Install this package +```bash +# Make sure you are in the root of this repo +pip install -e . +``` + +## Quickstart +Ensure that you are connected to the robot's WiFi. + +The following script allows you to move the robot without having to use tablet (which prompts you to enter a password once a day): +``` +python -m spot_wrapper.keyboard_teleop +``` +If you get an error about the e-stop, you just need to make sure that you run this script in another terminal: +``` +python -m spot_wrapper.estop +``` + +Read through the `spot_wrapper/keyboard_teleop.py` to see most of what this repo offers in terms of actuating the Spot and its arm. + +To receive/monitor data (vision/proprioception) from the robot, you can use these scripts: +``` +python -m spot_wrapper.view_camera +python -m spot_wrapper.view_arm_proprioception +python -m spot_wrapper.monitor_nav_pose +``` diff --git a/bd_spot_wrapper/__init__.py b/bd_spot_wrapper/__init__.py new file mode 100644 index 00000000..5ce8caf5 --- /dev/null +++ b/bd_spot_wrapper/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/bd_spot_wrapper/data/depth_transforms.txt b/bd_spot_wrapper/data/depth_transforms.txt new file mode 100644 index 00000000..0b2e1bd1 --- /dev/null +++ b/bd_spot_wrapper/data/depth_transforms.txt @@ -0,0 +1,11 @@ +SpotCamIds.FRONTLEFT_DEPTH +x: 0.41619532493078804 +y: 0.03740343144695029 +z: 0.023127331893806183 +W: 0.5178 X: 0.1422 Y: 0.8133 Z: -0.2241 + +SpotCamIds.FRONTRIGHT_DEPTH +x: 0.4164822634134684 +y: -0.03614789234067159 +z: 0.023188565383785213 +W: 0.5184 X: -0.1495 Y: 0.8101 Z: 0.2294 \ No newline at end of file diff --git a/bd_spot_wrapper/data/image_sources.txt b/bd_spot_wrapper/data/image_sources.txt new file mode 100644 index 00000000..d6bae4d3 --- /dev/null +++ b/bd_spot_wrapper/data/image_sources.txt @@ -0,0 +1,329 @@ +[name: "back_depth" +cols: 424 +rows: 240 +depth_scale: 999.0 +pinhole { + intrinsics { + focal_length { + x: 213.1068572998047 + y: 213.1068572998047 + } + principal_point { + x: 210.05804443359375 + y: 121.17723083496094 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "back_depth_in_visual_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.3891296386719 + y: 257.0436096191406 + } + principal_point { + x: 316.3995056152344 + y: 235.3182830810547 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "back_fisheye_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.3891296386719 + y: 257.0436096191406 + } + principal_point { + x: 316.3995056152344 + y: 235.3182830810547 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "frontleft_depth" +cols: 424 +rows: 240 +depth_scale: 999.0 +pinhole { + intrinsics { + focal_length { + x: 214.35023498535156 + y: 214.35023498535156 + } + principal_point { + x: 206.67208862304688 + y: 120.26058959960938 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "frontleft_depth_in_visual_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 258.2867126464844 + y: 257.8219909667969 + } + principal_point { + x: 314.52911376953125 + y: 243.18246459960938 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "frontleft_fisheye_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 258.2867126464844 + y: 257.8219909667969 + } + principal_point { + x: 314.52911376953125 + y: 243.18246459960938 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "frontright_depth" +cols: 424 +rows: 240 +depth_scale: 999.0 +pinhole { + intrinsics { + focal_length { + x: 215.2163543701172 + y: 215.2163543701172 + } + principal_point { + x: 213.766357421875 + y: 121.92640686035156 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "frontright_depth_in_visual_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.6131591796875 + y: 257.16253662109375 + } + principal_point { + x: 317.3861083984375 + y: 244.52821350097656 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "frontright_fisheye_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.6131591796875 + y: 257.16253662109375 + } + principal_point { + x: 317.3861083984375 + y: 244.52821350097656 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "hand_color_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 552.0291012161067 + y: 552.0291012161067 + } + principal_point { + x: 320.0 + y: 240.0 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "hand_color_in_hand_depth_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 552.0291012161067 + y: 552.0291012161067 + } + principal_point { + x: 320.0 + y: 240.0 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "hand_depth" +cols: 224 +rows: 171 +depth_scale: 1000.0 +pinhole { + intrinsics { + focal_length { + x: 210.4172821044922 + y: 210.4172821044922 + } + principal_point { + x: 105.02783203125 + y: 87.0968246459961 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "hand_depth_in_hand_color_frame" +cols: 224 +rows: 171 +depth_scale: 1000.0 +pinhole { + intrinsics { + focal_length { + x: 210.4172821044922 + y: 210.4172821044922 + } + principal_point { + x: 105.02783203125 + y: 87.0968246459961 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "hand_image" +cols: 224 +rows: 171 +depth_scale: 1000.0 +pinhole { + intrinsics { + focal_length { + x: 210.4172821044922 + y: 210.4172821044922 + } + principal_point { + x: 105.02783203125 + y: 87.0968246459961 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "left_depth" +cols: 424 +rows: 240 +depth_scale: 999.0 +pinhole { + intrinsics { + focal_length { + x: 223.71405029296875 + y: 223.71405029296875 + } + principal_point { + x: 226.58502197265625 + y: 123.37451171875 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "left_depth_in_visual_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.03656005859375 + y: 256.6549987792969 + } + principal_point { + x: 317.0937805175781 + y: 242.4779815673828 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "left_fisheye_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 257.03656005859375 + y: 256.6549987792969 + } + principal_point { + x: 317.0937805175781 + y: 242.4779815673828 + } + } +} +image_type: IMAGE_TYPE_VISUAL +, name: "right_depth" +cols: 424 +rows: 240 +depth_scale: 999.0 +pinhole { + intrinsics { + focal_length { + x: 221.22935485839844 + y: 221.22935485839844 + } + principal_point { + x: 228.81666564941406 + y: 118.76554870605469 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "right_depth_in_visual_frame" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 258.142822265625 + y: 257.6732177734375 + } + principal_point { + x: 317.4607849121094 + y: 248.84344482421875 + } + } +} +image_type: IMAGE_TYPE_DEPTH +, name: "right_fisheye_image" +cols: 640 +rows: 480 +pinhole { + intrinsics { + focal_length { + x: 258.142822265625 + y: 257.6732177734375 + } + principal_point { + x: 317.4607849121094 + y: 248.84344482421875 + } + } +} +image_type: IMAGE_TYPE_VISUAL +] \ No newline at end of file diff --git a/bd_spot_wrapper/data/transforms_example.txt b/bd_spot_wrapper/data/transforms_example.txt new file mode 100644 index 00000000..d10b7c4d --- /dev/null +++ b/bd_spot_wrapper/data/transforms_example.txt @@ -0,0 +1,89 @@ +{'gpe': parent_frame_name: "odom" +parent_tform_child { + position { + x: -1.0544086694717407 + y: -1.139175534248352 + z: -4.46881103515625 + } + rotation { + x: -0.002153972629457712 + y: 0.010553546249866486 + z: -2.2733371224603616e-05 + w: 0.9999419450759888 + } +} +, 'flat_body': parent_frame_name: "body" +parent_tform_child { + position { + } + rotation { + x: 0.001049818005412817 + y: 0.0017311159754171968 + z: 1.8363418803346576e-06 + w: 0.9999979734420776 + } +} +, 'hand': parent_frame_name: "body" +parent_tform_child { + position { + x: 0.5526498556137085 + y: 0.002247427823022008 + z: 0.26357635855674744 + } + rotation { + x: -0.0009961266769096255 + y: 0.0019155505578964949 + z: 0.005823923274874687 + w: 0.9999806880950928 + } +} +, 'odom': parent_frame_name: "body" +parent_tform_child { + position { + x: 1.564515471458435 + y: 0.07693791389465332 + z: 4.223989009857178 + } + rotation { + x: 0.00032411710708402097 + y: 0.0019984564278274775 + z: -0.3749612271785736 + w: 0.9270382523536682 + } +} +, 'vision': parent_frame_name: "body" +parent_tform_child { + position { + x: 1.4845632314682007 + y: -0.024931997060775757 + z: -0.1606001853942871 + } + rotation { + x: -0.0016495820600539446 + y: 0.0011737799504771829 + z: -0.9973124861717224 + w: 0.07323848456144333 + } +} +, 'link_wr1': parent_frame_name: "body" +parent_tform_child { + position { + x: 0.35709452629089355 + y: -2.9751361580565572e-05 + z: 0.26432785391807556 + } + rotation { + x: -0.0009961266769096255 + y: 0.0019155505578964949 + z: 0.005823923274874687 + w: 0.9999806880950928 + } +} +, 'body': parent_tform_child { + position { + } + rotation { + w: 1.0 + } +} +} diff --git a/bd_spot_wrapper/generate_executables.py b/bd_spot_wrapper/generate_executables.py new file mode 100644 index 00000000..85fa11e8 --- /dev/null +++ b/bd_spot_wrapper/generate_executables.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import os.path as osp +import sys + +this_dir = osp.dirname(osp.abspath(__file__)) +base_dir = osp.join(this_dir, "spot_wrapper") +bin_dir = osp.join(os.environ["CONDA_PREFIX"], "bin") + +orig_to_alias = { + "estop": "spot_estop", + "headless_estop": "spot_headless_estop", + "home_robot": "spot_reset_home", + "keyboard_teleop": "spot_keyboard_teleop", + "monitor_nav_pose": "spot_monitor_nav_pose", + "roll_over": "spot_roll_over", + "selfright": "spot_selfright", + "sit": "spot_sit", + "stand": "spot_stand", + "view_arm_proprioception": "spot_view_arm_proprioception", + "view_camera": "spot_view_camera", +} + + +print("Generating executables...") +for orig, alias in orig_to_alias.items(): + exe_path = osp.join(bin_dir, alias) + data = f"#!/usr/bin/env bash \n{sys.executable} -m spot_wrapper.{orig} $@\n" + with open(exe_path, "w") as f: + f.write(data) + os.chmod(exe_path, 33277) + print("Added:", alias) +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("THESE EXECUTABLES ARE ONLY VISIBLE TO THE CURRENT CONDA ENV!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") diff --git a/bd_spot_wrapper/requirements.txt b/bd_spot_wrapper/requirements.txt new file mode 100644 index 00000000..c2aa96a9 --- /dev/null +++ b/bd_spot_wrapper/requirements.txt @@ -0,0 +1,3 @@ +bosdyn-client +bosdyn-api +opencv-python diff --git a/bd_spot_wrapper/setup.py b/bd_spot_wrapper/setup.py new file mode 100644 index 00000000..0b336637 --- /dev/null +++ b/bd_spot_wrapper/setup.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import setuptools + +setuptools.setup( + name="spot_wrapper", + version="0.1", + author="Naoki Yokoyama", + author_email="naokiyokoyama@github", + description="Python wrapper for Boston Dynamics Spot Robot", + url="https://github.com/naokiyokoyama/bd_spot_wrapper", + packages=setuptools.find_packages(), +) diff --git a/bd_spot_wrapper/spot_wrapper/__init__.py b/bd_spot_wrapper/spot_wrapper/__init__.py new file mode 100644 index 00000000..5ce8caf5 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/bd_spot_wrapper/spot_wrapper/draw_square.py b/bd_spot_wrapper/spot_wrapper/draw_square.py new file mode 100644 index 00000000..69d8adf8 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/draw_square.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +from spot_wrapper.spot import Spot + +SQUARE_CENTER = np.array([0.5, 0.0, 0.6]) +SQUARE_SIDE = 0.4 +GRIPPER_WAYPOINTS = [ + SQUARE_CENTER, + SQUARE_CENTER + np.array([0.0, 0.0, SQUARE_SIDE / 2]), + SQUARE_CENTER + np.array([0.0, SQUARE_SIDE / 2, SQUARE_SIDE / 2]), + SQUARE_CENTER + np.array([0.0, -SQUARE_SIDE / 2, SQUARE_SIDE / 2]), + SQUARE_CENTER + np.array([0.0, -SQUARE_SIDE / 2, -SQUARE_SIDE / 2]), + SQUARE_CENTER + np.array([0.0, SQUARE_SIDE / 2, -SQUARE_SIDE / 2]), + SQUARE_CENTER, +] + + +def main(spot: Spot): + spot.power_on() + spot.blocking_stand() + + # Open the gripper + spot.open_gripper() + + # Move arm to initial configuration + try: + for point in GRIPPER_WAYPOINTS: + spot.loginfo("TRAVELING TO WAYPOINT") + cmd_id = spot.move_gripper_to_point(point, [0.0, 0.0, 0.0]) + spot.block_until_arm_arrives(cmd_id, timeout_sec=10) + spot.loginfo("REACHED WAYPOINT") + finally: + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("DrawSquare") + with spot.get_lease() as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/estop.py b/bd_spot_wrapper/spot_wrapper/estop.py new file mode 100644 index 00000000..01b7a765 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/estop.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# Copyright (c) 2021 Boston Dynamics, Inc. All rights reserved. +# +# Downloading, reproducing, distributing or otherwise using the SDK Software +# is subject to the terms and conditions of the Boston Dynamics Software +# Development Kit License (20191101-BDSDK-SL). + +"""Provides a programmatic estop to stop the robot.""" +from __future__ import print_function + +import argparse +import curses +import logging +import os +import signal +import sys +import time + +import bosdyn.client.util +from bosdyn.client.estop import EstopClient, EstopEndpoint, EstopKeepAlive +from bosdyn.client.robot_state import RobotStateClient + +try: + SPOT_ADMIN_PW = os.environ["SPOT_ADMIN_PW"] +except KeyError: + raise RuntimeError( + "\nSPOT_ADMIN_PW not found as an environment variable!\n" + "Please run:\n" + "echo 'export SPOT_ADMIN_PW='>> ~/.bashrc\n or for MacOS,\n" + "echo 'export SPOT_ADMIN_PW='>> ~/.bash_profile\n" + "Then:\nsource ~/.bashrc\nor\nsource ~/.bash_profile" + ) + +try: + SPOT_IP = os.environ["SPOT_IP"] +except KeyError: + raise RuntimeError( + "\nSPOT_IP not found as an environment variable!\n" + "Please run:\n" + "echo 'export SPOT_IP='>> ~/.bashrc\n or for MacOS,\n" + "echo 'export SPOT_IP='>> ~/.bash_profile\n" + "Then:\nsource ~/.bashrc\nor\nsource ~/.bash_profile" + ) + + +class EstopNoGui: + """Provides a software estop without a GUI. + + To use this estop, create an instance of the EstopNoGui class and use the stop() and allow() + functions programmatically. + """ + + def __init__(self, client, timeout_sec, name=None): + # Force server to set up a single endpoint system + ep = EstopEndpoint(client, name, timeout_sec) + ep.force_simple_setup() + + # Begin periodic check-in between keep-alive and robot + self.estop_keep_alive = EstopKeepAlive(ep) + + # Release the estop + self.estop_keep_alive.allow() + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleanly shut down estop on exit.""" + self.estop_keep_alive.end_periodic_check_in() + + def stop(self): + self.estop_keep_alive.stop() + + def allow(self): + self.estop_keep_alive.allow() + + def settle_then_cut(self): + self.estop_keep_alive.settle_then_cut() + + +def main(argv): + """If this file is the main, create an instance of EstopNoGui and wait for user to terminate. + + This has little practical use, because calling the function this way does not give the user + any way to trigger an estop from the terminal. + """ + parser = argparse.ArgumentParser() + bosdyn.client.util.add_common_arguments(parser) + parser.add_argument( + "-t", "--timeout", type=float, default=5, help="Timeout in seconds" + ) + options = parser.parse_args(argv) + + # Set username to admin and read PW from os variables + options.username = "admin" + options.password = SPOT_ADMIN_PW + + # Create robot object + sdk = bosdyn.client.create_standard_sdk("estop_nogui") + robot = sdk.create_robot(options.hostname) + robot.authenticate(options.username, options.password) + + # Create estop client for the robot + estop_client = robot.ensure_client(EstopClient.default_service_name) + + # Create nogui estop + estop_nogui = EstopNoGui(estop_client, options.timeout, "Estop NoGUI") + + # Create robot state client for the robot + state_client = robot.ensure_client(RobotStateClient.default_service_name) + + # Initialize curses screen display + stdscr = curses.initscr() + + def cleanup_example(msg): + """Shut down curses and exit the program.""" + print("Exiting") + # pylint: disable=unused-argument + estop_nogui.estop_keep_alive.shutdown() + + # Clean up and close curses + stdscr.keypad(False) + curses.echo() + stdscr.nodelay(False) + curses.endwin() + print(msg) + + def clean_exit(msg=""): + cleanup_example(msg) + exit(0) + + def sigint_handler(sig, frame): + """Exit the application on interrupt.""" + clean_exit() + + def run_example(): + """Run the actual example with the curses screen display""" + # Set up curses screen display to monitor for stop request + curses.noecho() + stdscr.keypad(True) + stdscr.nodelay(True) + curses.start_color() + curses.init_pair(1, curses.COLOR_GREEN, curses.COLOR_BLACK) + curses.init_pair(2, curses.COLOR_YELLOW, curses.COLOR_BLACK) + curses.init_pair(3, curses.COLOR_RED, curses.COLOR_BLACK) + # If terminal cannot handle colors, do not proceed + if not curses.has_colors(): + return + + # Curses eats Ctrl-C keyboard input, but keep a SIGINT handler around for + # explicit kill signals outside of the program. + signal.signal(signal.SIGINT, sigint_handler) + + # Clear screen + stdscr.clear() + + # Display usage instructions in terminal + stdscr.addstr("Estop w/o GUI running.\n") + stdscr.addstr("\n") + stdscr.addstr("[q] or [Ctrl-C]: Quit\n", curses.color_pair(2)) + stdscr.addstr("[SPACE]: Trigger estop\n", curses.color_pair(2)) + stdscr.addstr("[r]: Release estop\n", curses.color_pair(2)) + stdscr.addstr("[s]: Settle then cut estop\n", curses.color_pair(2)) + + # Monitor estop until user exits + while True: + # Retrieve user input (non-blocking) + c = stdscr.getch() + + try: + if c == ord(" "): + estop_nogui.stop() + if c == ord("r"): + estop_nogui.allow() + if c == ord("q") or c == 3: + clean_exit("Exit on user input") + if c == ord("s"): + estop_nogui.settle_then_cut() + # If the user attempts to toggle estop without valid endpoint + except bosdyn.client.estop.EndpointUnknownError: + clean_exit("This estop endpoint no longer valid. Exiting...") + + # Check if robot is estopped by any estops + estop_status = "NOT_STOPPED\n" + estop_status_color = curses.color_pair(1) + state = state_client.get_robot_state() + estop_states = state.estop_states + for estop_state in estop_states: + state_str = estop_state.State.Name(estop_state.state) + if state_str == "STATE_ESTOPPED": + estop_status = "STOPPED\n" + estop_status_color = curses.color_pair(3) + break + elif state_str == "STATE_UNKNOWN": + estop_status = "ERROR\n" + estop_status_color = curses.color_pair(3) + elif state_str == "STATE_NOT_ESTOPPED": + pass + else: + # Unknown estop status + clean_exit() + + # Display current estop status + if not estop_nogui.estop_keep_alive.status_queue.empty(): + latest_status = estop_nogui.estop_keep_alive.status_queue.get()[ + 1 + ].strip() + if latest_status != "": + # If you lose this estop endpoint, report it to user + stdscr.addstr(7, 0, latest_status, curses.color_pair(3)) + stdscr.addstr(6, 0, estop_status, estop_status_color) + + # Slow down loop + time.sleep(0.5) + + # Run all curses code in a try so we can cleanly exit if something goes wrong + try: + run_example() + except Exception as e: + cleanup_example(e) + raise e + + +if __name__ == "__main__": + # Open terminal interface and hold estop until user exits with SIGINT + if len(sys.argv) == 1: + sys.argv.append(SPOT_IP) + if not main(sys.argv[1:]): + sys.exit(1) diff --git a/bd_spot_wrapper/spot_wrapper/headless_estop.py b/bd_spot_wrapper/spot_wrapper/headless_estop.py new file mode 100644 index 00000000..3897f6f6 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/headless_estop.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import re +import signal +import struct +import threading +import time + +from bosdyn.client.estop import EstopClient +from spot_wrapper.estop import EstopNoGui +from spot_wrapper.spot import Spot +from spot_wrapper.utils import say + +""" +To run without sudo, you need to run this command: +sudo gpasswd -a $USER input + +and then reboot. +""" + + +class MyKeyEventClass2(object): + def __init__(self): + self.done = False + signal.signal(signal.SIGINT, self.cleanup) + + with open("/proc/bus/input/devices") as f: + devices_file_contents = f.read() + + # Spot-related code + spot = Spot("HeadlessEstop") + estop_client = spot.robot.ensure_client(EstopClient.default_service_name) + self.estop_nogui = EstopNoGui(estop_client, 5, "Estop NoGUI") + say("Headless e-stopping program initialized") + + for handlers in re.findall( + r"""H: Handlers=([^\n]+)""", devices_file_contents, re.DOTALL + ): + dev_event_file = "/dev/input/event" + re.search( + r"event(\d+)", handlers + ).group(1) + if "kbd" in handlers: + t = threading.Thread( + target=self.read_events, kwargs={"dev_event_file": dev_event_file} + ) + t.daemon = True + t.start() + + while not self.done: # Wait for Ctrl+C + time.sleep(0.5) + + def cleanup(self, signum, frame): + self.done = True + + def read_events(self, dev_event_file): + print("Listening for kbd events on dev_event_file=" + str(dev_event_file)) + try: + of = open(dev_event_file, "rb") + except IOError as e: + if e.strerror == "Permission denied": + print( + "You don't have read permission on ({}). Are you root?".format( + dev_event_file + ) + ) + return + while True: + event_bin_format = ( + "llHHI" # See kernel documentation for 'struct input_event' + ) + # For details, read section 5 of this document: + # https://www.kernel.org/doc/Documentation/input/input.txt + data = of.read(struct.calcsize(event_bin_format)) + seconds, microseconds, e_type, code, value = struct.unpack( + event_bin_format, data + ) + full_time = seconds + microseconds / 1000000 + if e_type == 0x1: # 0x1 == EV_KEY means key press or release. + d = ( + "RELEASE" if value == 0 else "PRESS" + ) # value == 0 release, value == 1 press + print( + "Got key " + + d + + " from " + + str(dev_event_file) + + ": t=" + + str(full_time) + + "us type=" + + str(e_type) + + " code=" + + str(code) + ) + + # Spot-related code + if d == "PRESS": + if 0: # str(code) == "108": # down + self.estop_nogui.settle_then_cut() + say("Activating e-stop") + elif str(code) == "103": # up + self.estop_nogui.allow() + say("Releasing e-stop") + + +if __name__ == "__main__": + try: + a = MyKeyEventClass2() + finally: + say("Headless e-stopping program terminating.") diff --git a/bd_spot_wrapper/spot_wrapper/home_robot.py b/bd_spot_wrapper/spot_wrapper/home_robot.py new file mode 100644 index 00000000..ba89dd3e --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/home_robot.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from spot_wrapper.spot import Spot + +if __name__ == "__main__": + spot = Spot("NavPoseMonitor") + spot.home_robot() diff --git a/bd_spot_wrapper/spot_wrapper/keyboard_teleop.py b/bd_spot_wrapper/spot_wrapper/keyboard_teleop.py new file mode 100644 index 00000000..02b09faa --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/keyboard_teleop.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import curses +import os +import signal +import time + +import numpy as np +from spot_wrapper.spot import Spot, SpotCamIds + +MOVE_INCREMENT = 0.02 +TILT_INCREMENT = 5.0 +BASE_ANGULAR_VEL = np.deg2rad(50) +BASE_LIN_VEL = 0.75 +DOCK_ID = int(os.environ.get("SPOT_DOCK_ID", 520)) +UPDATE_PERIOD = 0.2 + +# Where the gripper goes to upon initialization +INITIAL_POINT = np.array([0.5, 0.0, 0.35]) +INITIAL_RPY = np.deg2rad([0.0, 45.0, 0.0]) +KEY2GRIPPERMOVEMENT = { + "w": np.array([0.0, 0.0, MOVE_INCREMENT, 0.0, 0.0, 0.0]), # move up + "s": np.array([0.0, 0.0, -MOVE_INCREMENT, 0.0, 0.0, 0.0]), # move down + "a": np.array([0.0, MOVE_INCREMENT, 0.0, 0.0, 0.0, 0.0]), # move left + "d": np.array([0.0, -MOVE_INCREMENT, 0.0, 0.0, 0.0, 0.0]), # move right + "q": np.array([MOVE_INCREMENT, 0.0, 0.0, 0.0, 0.0, 0.0]), # move forward + "e": np.array([-MOVE_INCREMENT, 0.0, 0.0, 0.0, 0.0, 0.0]), # move backward + "i": np.deg2rad([0.0, 0.0, 0.0, 0.0, -TILT_INCREMENT, 0.0]), # pitch up + "k": np.deg2rad([0.0, 0.0, 0.0, 0.0, TILT_INCREMENT, 0.0]), # pitch down + "j": np.deg2rad([0.0, 0.0, 0.0, 0.0, 0.0, TILT_INCREMENT]), # pan left + "l": np.deg2rad([0.0, 0.0, 0.0, 0.0, 0.0, -TILT_INCREMENT]), # pan right +} +KEY2BASEMOVEMENT = { + "q": [0.0, 0.0, BASE_ANGULAR_VEL], # turn left + "e": [0.0, 0.0, -BASE_ANGULAR_VEL], # turn right + "w": [BASE_LIN_VEL, 0.0, 0.0], # go forward + "s": [-BASE_LIN_VEL, 0.0, 0.0], # go backward + "a": [0.0, BASE_LIN_VEL, 0.0], # strafe left + "d": [0.0, -BASE_LIN_VEL, 0.0], # strafe right +} +INSTRUCTIONS = ( + "Use 'wasdqe' for translating gripper, 'ijkl' for rotating.\n" + "Use 'g' to grasp whatever is at the center of the gripper image.\n" + "Press 't' to toggle between controlling the arm or the base\n" + "('wasdqe' will control base).\n" + "Press 'z' to quit.\n" +) + + +def move_to_initial(spot): + point = INITIAL_POINT + rpy = INITIAL_RPY + cmd_id = spot.move_gripper_to_point(point, rpy) + spot.block_until_arm_arrives(cmd_id, timeout_sec=1.5) + cement_arm_joints(spot) + + return point, rpy + + +def cement_arm_joints(spot): + arm_proprioception = spot.get_arm_proprioception() + current_positions = np.array( + [v.position.value for v in arm_proprioception.values()] + ) + spot.set_arm_joint_positions(positions=current_positions, travel_time=UPDATE_PERIOD) + + +def raise_error(sig, frame): + raise RuntimeError + + +def main(spot: Spot): + """Uses IK to move the arm by setting hand poses""" + spot.power_on() + spot.blocking_stand() + + # Open the gripper + spot.open_gripper() + + # Move arm to initial configuration + point, rpy = move_to_initial(spot) + control_arm = False + + # Start in-terminal GUI + stdscr = curses.initscr() + stdscr.nodelay(True) + curses.noecho() + signal.signal(signal.SIGINT, raise_error) + stdscr.addstr(INSTRUCTIONS) + last_execution = time.time() + try: + while True: + point_rpy = np.concatenate([point, rpy]) + pressed_key = stdscr.getch() + + key_not_applicable = False + + # Don't update if no key was pressed or we updated too recently + if pressed_key == -1 or time.time() - last_execution < UPDATE_PERIOD: + continue + + pressed_key = chr(pressed_key) + + if pressed_key == "z": + # Quit + break + elif pressed_key == "t": + # Toggle between controlling arm or base + control_arm = not control_arm + if not control_arm: + cement_arm_joints(spot) + spot.loginfo(f"control_arm: {control_arm}") + time.sleep(0.2) # Wait before we starting listening again + elif pressed_key == "g": + # Grab whatever object is at the center of hand RGB camera image + image_responses = spot.get_image_responses([SpotCamIds.HAND_COLOR]) + hand_image_response = image_responses[0] # only expecting one image + spot.grasp_point_in_image(hand_image_response) + # Retract arm back to initial configuration + point, rpy = move_to_initial(spot) + elif pressed_key == "r": + # Open gripper + spot.open_gripper() + elif pressed_key == "n": + try: + spot.dock(DOCK_ID) + spot.home_robot() + except Exception: + print("Dock was not found!") + elif pressed_key == "i": + point, rpy = move_to_initial(spot) + else: + # Tele-operate either the gripper pose or the base + if control_arm: + if pressed_key in KEY2GRIPPERMOVEMENT: + # Move gripper + point_rpy += KEY2GRIPPERMOVEMENT[pressed_key] + point, rpy = point_rpy[:3], point_rpy[3:] + cmd_id = spot.move_gripper_to_point(point, rpy) + print("Gripper destination: ", point, rpy) + spot.block_until_arm_arrives( + cmd_id, timeout_sec=UPDATE_PERIOD * 0.5 + ) + elif pressed_key in KEY2BASEMOVEMENT: + # Move base + x_vel, y_vel, ang_vel = KEY2BASEMOVEMENT[pressed_key] + spot.set_base_velocity( + x_vel=x_vel, + y_vel=y_vel, + ang_vel=ang_vel, + vel_time=UPDATE_PERIOD * 2, + ) + else: + key_not_applicable = True + + if not key_not_applicable: + last_execution = time.time() + + finally: + spot.power_off() + curses.echo() + stdscr.nodelay(False) + curses.endwin() + + +if __name__ == "__main__": + spot = Spot("ArmKeyboardTeleop") + with spot.get_lease(hijack=True) as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/monitor_nav_pose.py b/bd_spot_wrapper/spot_wrapper/monitor_nav_pose.py new file mode 100644 index 00000000..29261685 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/monitor_nav_pose.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +import numpy as np +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + while True: + x, y, yaw = spot.get_xy_yaw() + spot.loginfo(f"x: {x}, y: {y}, yaw: {np.rad2deg(yaw)}") + time.sleep(1 / 30.0) + + +if __name__ == "__main__": + spot = Spot("NavPoseMonitor") + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/roll_over.py b/bd_spot_wrapper/spot_wrapper/roll_over.py new file mode 100644 index 00000000..801c481d --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/roll_over.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + """Make Spot stand""" + spot.power_on() + spot.roll_over() + + # Wait 5 seconds to before powering down... + while True: + pass + time.sleep(5) + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("RolloverClient") + with spot.get_lease() as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/selfright.py b/bd_spot_wrapper/spot_wrapper/selfright.py new file mode 100644 index 00000000..24314acf --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/selfright.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + """Make Spot stand""" + spot.power_on() + spot.blocking_selfright() + + # Wait 3 seconds to before powering down... + while True: + pass + time.sleep(3) + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("BasicSelfRightClient") + with spot.get_lease() as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/sit.py b/bd_spot_wrapper/spot_wrapper/sit.py new file mode 100644 index 00000000..0b4e0c88 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/sit.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + """Make Spot stand""" + spot.power_on() + spot.sit() + + # Wait 5 seconds to before powering down... + while True: + pass + time.sleep(5) + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("RolloverClient") + with spot.get_lease() as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/spot.py b/bd_spot_wrapper/spot_wrapper/spot.py new file mode 100644 index 00000000..39bf3ef3 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/spot.py @@ -0,0 +1,787 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +# Copyright (c) 2021 Boston Dynamics, Inc. All rights reserved. +# +# Downloading, reproducing, distributing or otherwise using the SDK Software +# is subject to the terms and conditions of the Boston Dynamics Software +# Development Kit License (20191101-BDSDK-SL). + +""" Easy-to-use wrapper for properly controlling Spot """ +import os +import os.path as osp +import time +from collections import OrderedDict + +import bosdyn.client +import bosdyn.client.lease +import bosdyn.client.util +import cv2 +import numpy as np +from bosdyn import geometry +from bosdyn.api import ( + arm_command_pb2, + basic_command_pb2, + geometry_pb2, + image_pb2, + manipulation_api_pb2, + robot_command_pb2, + synchronized_command_pb2, + trajectory_pb2, +) +from bosdyn.api.geometry_pb2 import SE2Velocity, SE2VelocityLimit, Vec2 +from bosdyn.api.spot import robot_command_pb2 as spot_command_pb2 +from bosdyn.client import math_helpers +from bosdyn.client.docking import blocking_dock_robot, blocking_undock +from bosdyn.client.frame_helpers import ( + GRAV_ALIGNED_BODY_FRAME_NAME, + HAND_FRAME_NAME, + VISION_FRAME_NAME, + get_vision_tform_body, +) +from bosdyn.client.image import ImageClient, build_image_request +from bosdyn.client.manipulation_api_client import ManipulationApiClient +from bosdyn.client.robot_command import ( + RobotCommandBuilder, + RobotCommandClient, + block_until_arm_arrives, + blocking_selfright, + blocking_stand, +) +from bosdyn.client.robot_state import RobotStateClient +from google.protobuf import wrappers_pb2 + +# Get Spot password and IP address +env_err_msg = ( + "\n{var_name} not found as an environment variable!\n" + "Please run:\n" + "echo 'export {var_name}=' >> ~/.bashrc\nor for MacOS,\n" + "echo 'export {var_name}=' >> ~/.bash_profile\n" + "Then:\nsource ~/.bashrc\nor\nsource ~/.bash_profile" +) +try: + SPOT_ADMIN_PW = os.environ["SPOT_ADMIN_PW"] +except KeyError: + raise RuntimeError(env_err_msg.format(var_name="SPOT_ADMIN_PW")) +try: + SPOT_IP = os.environ["SPOT_IP"] +except KeyError: + raise RuntimeError(env_err_msg.format(var_name="SPOT_IP")) + +ARM_6DOF_NAMES = [ + "arm0.sh0", + "arm0.sh1", + "arm0.el0", + "arm0.el1", + "arm0.wr0", + "arm0.wr1", +] + +HOME_TXT = osp.join(osp.dirname(osp.abspath(__file__)), "home.txt") + + +class SpotCamIds: + r"""Enumeration of types of cameras.""" + + BACK_DEPTH = "back_depth" + BACK_DEPTH_IN_VISUAL_FRAME = "back_depth_in_visual_frame" + BACK_FISHEYE = "back_fisheye_image" + FRONTLEFT_DEPTH = "frontleft_depth" + FRONTLEFT_DEPTH_IN_VISUAL_FRAME = "frontleft_depth_in_visual_frame" + FRONTLEFT_FISHEYE = "frontleft_fisheye_image" + FRONTRIGHT_DEPTH = "frontright_depth" + FRONTRIGHT_DEPTH_IN_VISUAL_FRAME = "frontright_depth_in_visual_frame" + FRONTRIGHT_FISHEYE = "frontright_fisheye_image" + HAND_COLOR = "hand_color_image" + HAND_COLOR_IN_HAND_DEPTH_FRAME = "hand_color_in_hand_depth_frame" + HAND_DEPTH = "hand_depth" + HAND_DEPTH_IN_HAND_COLOR_FRAME = "hand_depth_in_hand_color_frame" + HAND = "hand_image" + LEFT_DEPTH = "left_depth" + LEFT_DEPTH_IN_VISUAL_FRAME = "left_depth_in_visual_frame" + LEFT_FISHEYE = "left_fisheye_image" + RIGHT_DEPTH = "right_depth" + RIGHT_DEPTH_IN_VISUAL_FRAME = "right_depth_in_visual_frame" + RIGHT_FISHEYE = "right_fisheye_image" + + +# CamIds that need to be rotated by 270 degrees in order to appear upright +SHOULD_ROTATE = [ + SpotCamIds.FRONTLEFT_DEPTH, + SpotCamIds.FRONTRIGHT_DEPTH, + SpotCamIds.HAND_DEPTH, + SpotCamIds.HAND, +] + + +class Spot: + def __init__(self, client_name_prefix): + bosdyn.client.util.setup_logging() + sdk = bosdyn.client.create_standard_sdk(client_name_prefix) + robot = sdk.create_robot(SPOT_IP) + robot.authenticate("admin", SPOT_ADMIN_PW) + robot.time_sync.wait_for_sync() + self.robot = robot + self.command_client = None + self.spot_lease = None + + # Get clients + self.command_client = robot.ensure_client( + RobotCommandClient.default_service_name + ) + self.image_client = robot.ensure_client(ImageClient.default_service_name) + self.manipulation_api_client = robot.ensure_client( + ManipulationApiClient.default_service_name + ) + self.robot_state_client = robot.ensure_client( + RobotStateClient.default_service_name + ) + + # Used to re-center origin of global frame + if osp.isfile(HOME_TXT): + with open(HOME_TXT) as f: + data = f.read() + self.global_T_home = np.array([float(d) for d in data.split(", ")[:9]]) + self.global_T_home = self.global_T_home.reshape([3, 3]) + self.robot_recenter_yaw = float(data.split(", ")[-1]) + else: + self.global_T_home = None + self.robot_recenter_yaw = None + + # Print the battery charge level of the robot + self.loginfo(f"Current battery charge: {self.get_battery_charge()}%") + + def get_lease(self, hijack=False): + # Make sure a lease for this client isn't already active + assert self.spot_lease is None + self.spot_lease = SpotLease(self, hijack=hijack) + return self.spot_lease + + def get_cmd_feedback(self, cmd_id): + return self.command_client.robot_command_feedback(cmd_id) + + def is_estopped(self): + return self.robot.is_estopped() + + def power_on(self, timeout_sec=20): + self.robot.power_on(timeout_sec=timeout_sec) + assert self.robot.is_powered_on(), "Robot power on failed." + self.loginfo("Robot powered on.") + + def power_off(self, cut_immediately=False, timeout_sec=20): + self.loginfo("Powering robot off...") + self.robot.power_off(cut_immediately=cut_immediately, timeout_sec=timeout_sec) + assert not self.robot.is_powered_on(), "Robot power off failed." + self.loginfo("Robot safely powered off.") + + def blocking_stand(self, timeout_sec=10): + self.loginfo("Commanding robot to stand (blocking)...") + blocking_stand(self.command_client, timeout_sec=timeout_sec) + self.loginfo("Robot standing.") + + def stand(self, timeout_sec=10): + stand_command = RobotCommandBuilder.synchro_stand_command() + cmd_id = self.command_client.robot_command(stand_command, timeout=timeout_sec) + return cmd_id + + def blocking_selfright(self, timeout_sec=20): + self.loginfo("Commanding robot to self-right (blocking)...") + blocking_selfright(self.command_client, timeout_sec=timeout_sec) + self.loginfo("Robot has self-righted.") + + def loginfo(self, *args, **kwargs): + self.robot.logger.info(*args, **kwargs) + + def open_gripper(self): + """Does not block, be careful!""" + gripper_command = RobotCommandBuilder.claw_gripper_open_command() + self.command_client.robot_command(gripper_command) + + def close_gripper(self): + """Does not block, be careful!""" + gripper_command = RobotCommandBuilder.claw_gripper_close_command() + self.command_client.robot_command(gripper_command) + + def rotate_gripper_with_delta(self, wrist_yaw=0.0, wrist_roll=0.0): + """ + Takes in relative wrist rotations targets and moves each wrist joint to the corresponding target. + Waits for 0.5 sec after issuing motion command + :param wrist_yaw: relative yaw for wrist in radians + :param wrist_roll: relative roll for wrist in radians + """ + print( + f"Rotating the wrist with the following relative rotations: yaw={wrist_yaw}, roll={wrist_roll}" + ) + + arm_joint_positions = self.get_arm_joint_positions(as_array=True) + # Maybe also wrap angles? + # Ordering: sh0, sh1, el0, el1, wr0, wr1 + joint_rotation_delta = np.array([0.0, 0.0, 0.0, 0.0, wrist_yaw, wrist_roll]) + new_arm_joint_states = np.add(arm_joint_positions, joint_rotation_delta) + self.set_arm_joint_positions(new_arm_joint_states) + time.sleep(0.5) + + def move_gripper_to_point(self, point, rotation): + """ + Moves EE to a point relative to body frame + :param point: XYZ location + :param rotation: Euler roll-pitch-yaw or WXYZ quaternion + :return: cmd_id + """ + if len(rotation) == 3: # roll pitch yaw Euler angles + roll, pitch, yaw = rotation + quat = geometry.EulerZXY(yaw=yaw, roll=roll, pitch=pitch).to_quaternion() + elif len(rotation) == 4: # w, x, y, z quaternion + w, x, y, z = rotation + quat = math_helpers.Quat(w=w, x=x, y=y, z=z) + else: + raise RuntimeError( + "rotation needs to have length 3 (euler) or 4 (quaternion)," + f"got {len(rotation)}" + ) + + hand_pose = math_helpers.SE3Pose(*point, quat) + hand_trajectory = trajectory_pb2.SE3Trajectory( + points=[trajectory_pb2.SE3TrajectoryPoint(pose=hand_pose.to_proto())] + ) + arm_cartesian_command = arm_command_pb2.ArmCartesianCommand.Request( + pose_trajectory_in_task=hand_trajectory, + root_frame_name=GRAV_ALIGNED_BODY_FRAME_NAME, + ) + + # Pack everything up in protos. + arm_command = arm_command_pb2.ArmCommand.Request( + arm_cartesian_command=arm_cartesian_command + ) + synchronized_command = synchronized_command_pb2.SynchronizedCommand.Request( + arm_command=arm_command + ) + command = robot_command_pb2.RobotCommand( + synchronized_command=synchronized_command + ) + cmd_id = self.command_client.robot_command(command) + + return cmd_id + + def block_until_arm_arrives(self, cmd_id, timeout_sec=5): + block_until_arm_arrives(self.command_client, cmd_id, timeout_sec=timeout_sec) + + def get_image_responses(self, sources, quality=None): + """Retrieve images from Spot's cameras + + :param sources: list containing camera uuids + :param quality: either an int or a list specifying what quality each source + should return its image with + :return: list containing bosdyn image response objects + """ + if quality is not None: + if isinstance(quality, int): + quality = [quality] * len(sources) + else: + assert len(quality) == len(sources) + sources = [build_image_request(src, q) for src, q in zip(sources, quality)] + image_responses = self.image_client.get_image(sources) + else: + image_responses = self.image_client.get_image_from_sources(sources) + + return image_responses + + def grasp_point_in_image( + self, + image_response, + pixel_xy=None, + timeout=10, + data_edge_timeout=2, + top_down_grasp=False, + horizontal_grasp=False, + ): + # If pixel location not provided, select the center pixel + if pixel_xy is None: + height = image_response.shot.image.rows + width = image_response.shot.image.cols + pixel_xy = [width // 2, height // 2] + + pick_vec = geometry_pb2.Vec2(x=pixel_xy[0], y=pixel_xy[1]) + grasp = manipulation_api_pb2.PickObjectInImage( + pixel_xy=pick_vec, + transforms_snapshot_for_camera=image_response.shot.transforms_snapshot, + frame_name_image_sensor=image_response.shot.frame_name_image_sensor, + camera_model=image_response.source.pinhole, + walk_gaze_mode=3, + ) + if top_down_grasp or horizontal_grasp: + if top_down_grasp: + # Add a constraint that requests that the x-axis of the gripper is + # pointing in the negative-z direction in the vision frame. + + # The axis on the gripper is the x-axis. + axis_on_gripper_ewrt_gripper = geometry_pb2.Vec3(x=1, y=0, z=0) + + # The axis in the vision frame is the negative z-axis + axis_to_align_with_ewrt_vo = geometry_pb2.Vec3(x=0, y=0, z=-1) + + else: + # Add a constraint that requests that the y-axis of the gripper is + # pointing in the positive-z direction in the vision frame. That means + # that the gripper is constrained to be rolled 90 degrees and pointed + # at the horizon. + + # The axis on the gripper is the y-axis. + axis_on_gripper_ewrt_gripper = geometry_pb2.Vec3(x=0, y=1, z=0) + + # The axis in the vision frame is the positive z-axis + axis_to_align_with_ewrt_vo = geometry_pb2.Vec3(x=0, y=0, z=1) + + grasp.grasp_params.grasp_params_frame_name = VISION_FRAME_NAME + # Add the vector constraint to our proto. + constraint = grasp.grasp_params.allowable_orientation.add() + constraint.vector_alignment_with_tolerance.axis_on_gripper_ewrt_gripper.CopyFrom( + axis_on_gripper_ewrt_gripper + ) + constraint.vector_alignment_with_tolerance.axis_to_align_with_ewrt_frame.CopyFrom( + axis_to_align_with_ewrt_vo + ) + + # Take anything within about 10 degrees for top-down or horizontal grasps. + constraint.vector_alignment_with_tolerance.threshold_radians = 1.0 * 2 + + # Ask the robot to pick up the object + grasp_request = manipulation_api_pb2.ManipulationApiRequest( + pick_object_in_image=grasp + ) + # Send the request + cmd_response = self.manipulation_api_client.manipulation_api_command( + manipulation_api_request=grasp_request + ) + + # Get feedback from the robot (WILL BLOCK TILL COMPLETION) + start_time = time.time() + success = False + while time.time() < start_time + timeout: + feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest( + manipulation_cmd_id=cmd_response.manipulation_cmd_id + ) + + # Send the request + response = self.manipulation_api_client.manipulation_api_feedback_command( + manipulation_api_feedback_request=feedback_request + ) + + print( + "Current grasp_point_in_image state: ", + manipulation_api_pb2.ManipulationFeedbackState.Name( + response.current_state + ), + ) + + if ( + response.current_state + == manipulation_api_pb2.MANIP_STATE_GRASP_PLANNING_WAITING_DATA_AT_EDGE + ) and time.time() > start_time + data_edge_timeout: + break + elif ( + response.current_state + == manipulation_api_pb2.MANIP_STATE_GRASP_SUCCEEDED + ): + success = True + break + elif response.current_state in [ + manipulation_api_pb2.MANIP_STATE_GRASP_FAILED, + manipulation_api_pb2.MANIP_STATE_GRASP_PLANNING_NO_SOLUTION, + ]: + break + + time.sleep(0.25) + return success + + def grasp_hand_depth(self, *args, **kwargs): + image_responses = self.get_image_responses( + # [SpotCamIds.HAND_DEPTH_IN_HAND_COLOR_FRAME] + [SpotCamIds.HAND_COLOR] + ) + hand_image_response = image_responses[0] # only expecting one image + return self.grasp_point_in_image(hand_image_response, *args, **kwargs) + + def set_base_velocity( + self, + x_vel, + y_vel, + ang_vel, + vel_time, + disable_obstacle_avoidance=False, + return_cmd=False, + ): + body_tform_goal = math_helpers.SE2Velocity(x=x_vel, y=y_vel, angular=ang_vel) + params = spot_command_pb2.MobilityParams( + obstacle_params=spot_command_pb2.ObstacleParams( + disable_vision_body_obstacle_avoidance=disable_obstacle_avoidance, + disable_vision_foot_obstacle_avoidance=False, + disable_vision_foot_constraint_avoidance=False, + obstacle_avoidance_padding=0.05, # in meters + ) + ) + command = RobotCommandBuilder.synchro_velocity_command( + v_x=body_tform_goal.linear_velocity_x, + v_y=body_tform_goal.linear_velocity_y, + v_rot=body_tform_goal.angular_velocity, + params=params, + ) + + if return_cmd: + return command + + cmd_id = self.command_client.robot_command( + command, end_time_secs=time.time() + vel_time + ) + + return cmd_id + + def set_base_position( + self, + x_pos, + y_pos, + yaw, + end_time, + relative=False, + max_fwd_vel=2, + max_hor_vel=2, + max_ang_vel=np.pi / 2, + disable_obstacle_avoidance=False, + blocking=False, + ): + vel_limit = SE2VelocityLimit( + max_vel=SE2Velocity( + linear=Vec2(x=max_fwd_vel, y=max_hor_vel), angular=max_ang_vel + ), + min_vel=SE2Velocity( + linear=Vec2(x=-max_fwd_vel, y=-max_hor_vel), angular=-max_ang_vel + ), + ) + params = spot_command_pb2.MobilityParams( + vel_limit=vel_limit, + obstacle_params=spot_command_pb2.ObstacleParams( + disable_vision_body_obstacle_avoidance=disable_obstacle_avoidance, + disable_vision_foot_obstacle_avoidance=False, + disable_vision_foot_constraint_avoidance=False, + obstacle_avoidance_padding=0.05, # in meters + ), + ) + curr_x, curr_y, curr_yaw = self.get_xy_yaw(use_boot_origin=True) + coors = np.array([x_pos, y_pos, 1.0]) + if relative: + local_T_global = self._get_local_T_global(curr_x, curr_y, curr_yaw) + x, y, w = local_T_global.dot(coors) + global_x_pos, global_y_pos = x / w, y / w + global_yaw = wrap_heading(curr_yaw + yaw) + else: + global_x_pos, global_y_pos, global_yaw = self.xy_yaw_home_to_global( + x_pos, y_pos, yaw + ) + robot_cmd = RobotCommandBuilder.synchro_se2_trajectory_point_command( + goal_x=global_x_pos, + goal_y=global_y_pos, + goal_heading=global_yaw, + frame_name=VISION_FRAME_NAME, + params=params, + ) + cmd_id = self.command_client.robot_command( + robot_cmd, end_time_secs=time.time() + end_time + ) + + if blocking: + cmd_status = None + while cmd_status != 1: + time.sleep(0.1) + feedback_resp = self.get_cmd_feedback(cmd_id) + cmd_status = ( + feedback_resp.feedback.synchronized_feedback + ).mobility_command_feedback.se2_trajectory_feedback.status + return None + + return cmd_id + + def get_robot_state(self): + return self.robot_state_client.get_robot_state() + + def get_battery_charge(self): + state = self.get_robot_state() + return state.power_state.locomotion_charge_percentage.value + + def roll_over(self, roll_over_left=True): + if roll_over_left: + dir_hint = basic_command_pb2.BatteryChangePoseCommand.Request.HINT_LEFT + else: + dir_hint = basic_command_pb2.BatteryChangePoseCommand.Request.HINT_RIGHT + cmd = RobotCommandBuilder.battery_change_pose_command(dir_hint=dir_hint) + self.command_client.robot_command(cmd) + + def sit(self): + cmd = RobotCommandBuilder.synchro_sit_command() + self.command_client.robot_command(cmd) + + def get_arm_proprioception(self, robot_state=None): + """Return state of each of the 6 joints of the arm""" + if robot_state is None: + robot_state = self.robot_state_client.get_robot_state() + arm_joint_states = OrderedDict( + { + i.name[len("arm0.") :]: i + for i in robot_state.kinematic_state.joint_states + if i.name in ARM_6DOF_NAMES + } + ) + + return arm_joint_states + + def get_proprioception(self, robot_state=None): + """Return state of each of the 6 joints of the arm""" + if robot_state is None: + robot_state = self.robot_state_client.get_robot_state() + joint_states = OrderedDict( + {i.name: i for i in robot_state.kinematic_state.joint_states} + ) + + return joint_states + + def get_arm_joint_positions(self, as_array=True): + """ + Gives in joint positions of the arm in radians in the following order + Ordering: sh0, sh1, el0, el1, wr0, wr1 + :param as_array: bool, True for output as an np.array, False for list + :return: 6 element data structure (np.array or list) of joint positions as radians + """ + arm_joint_states = self.get_arm_proprioception() + arm_joint_positions = np.fromiter( + (arm_joint_states[joint].position.value for joint in arm_joint_states), + float, + ) + + if as_array: + return arm_joint_positions + return arm_joint_positions.tolist() + + def set_arm_joint_positions( + self, positions, travel_time=1.0, max_vel=2.5, max_acc=15, return_cmd=False + ): + """ + Takes in 6 joint targets and moves each arm joint to the corresponding target. + Ordering: sh0, sh1, el0, el1, wr0, wr1 + :param positions: np.array or list of radians + :param travel_time: how long execution should take + :param max_vel: max allowable velocity + :param max_acc: max allowable acceleration + :return: cmd_id + """ + sh0, sh1, el0, el1, wr0, wr1 = positions + traj_point = RobotCommandBuilder.create_arm_joint_trajectory_point( + sh0, sh1, el0, el1, wr0, wr1, travel_time + ) + arm_joint_traj = arm_command_pb2.ArmJointTrajectory( + points=[traj_point], + maximum_velocity=wrappers_pb2.DoubleValue(value=max_vel), + maximum_acceleration=wrappers_pb2.DoubleValue(value=max_acc), + ) + command = make_robot_command(arm_joint_traj) + + if return_cmd: + return command + + cmd_id = self.command_client.robot_command(command) + + return cmd_id + + def set_base_vel_and_arm_pos( + self, + x_vel, + y_vel, + ang_vel, + arm_positions, + travel_time, + disable_obstacle_avoidance=False, + ): + base_cmd = self.set_base_velocity( + x_vel, + y_vel, + ang_vel, + vel_time=travel_time, + disable_obstacle_avoidance=disable_obstacle_avoidance, + return_cmd=True, + ) + arm_cmd = self.set_arm_joint_positions( + arm_positions, travel_time=travel_time, return_cmd=True + ) + synchro_command = RobotCommandBuilder.build_synchro_command(base_cmd, arm_cmd) + cmd_id = self.command_client.robot_command( + synchro_command, end_time_secs=time.time() + travel_time + ) + return cmd_id + + def get_xy_yaw(self, use_boot_origin=False, robot_state=None): + """ + Returns the relative x and y distance from start, as well as relative heading + """ + if robot_state is None: + robot_state = self.robot_state_client.get_robot_state() + robot_state_kin = robot_state.kinematic_state + self.body = get_vision_tform_body(robot_state_kin.transforms_snapshot) + robot_tform = self.body + yaw = math_helpers.quat_to_eulerZYX(robot_tform.rotation)[0] + if self.global_T_home is None or use_boot_origin: + return robot_tform.x, robot_tform.y, yaw + return self.xy_yaw_global_to_home(robot_tform.x, robot_tform.y, yaw) + + def xy_yaw_global_to_home(self, x, y, yaw): + x, y, w = self.global_T_home.dot(np.array([x, y, 1.0])) + x, y = x / w, y / w + + return x, y, wrap_heading(yaw - self.robot_recenter_yaw) + + def xy_yaw_home_to_global(self, x, y, yaw): + local_T_global = np.linalg.inv(self.global_T_home) + x, y, w = local_T_global.dot(np.array([x, y, 1.0])) + x, y = x / w, y / w + + return x, y, wrap_heading(self.robot_recenter_yaw - yaw) + + def _get_local_T_global(self, x=None, y=None, yaw=None): + if x is None: + x, y, yaw = self.get_xy_yaw(use_boot_origin=True) + # Create offset transformation matrix + local_T_global = np.array( + [ + [np.cos(yaw), -np.sin(yaw), x], + [np.sin(yaw), np.cos(yaw), y], + [0.0, 0.0, 1.0], + ] + ) + return local_T_global + + def home_robot(self): + x, y, yaw = self.get_xy_yaw(use_boot_origin=True) + local_T_global = self._get_local_T_global() + self.global_T_home = np.linalg.inv(local_T_global) + self.robot_recenter_yaw = yaw + + as_string = list(self.global_T_home.flatten()) + [yaw] + as_string = f"{as_string}"[1:-1] # [1:-1] removes brackets + with open(HOME_TXT, "w") as f: + f.write(as_string) + self.loginfo(f"Wrote:\n{as_string}\nto: {HOME_TXT}") + + def get_base_transform_to(self, child_frame): + kin_state = self.robot_state_client.get_robot_state().kinematic_state + kin_state = kin_state.transforms_snapshot.child_to_parent_edge_map.get( + child_frame + ).parent_tform_child + return kin_state.position, kin_state.rotation + + def dock(self, dock_id, home_robot=False): + blocking_dock_robot(self.robot, dock_id) + if home_robot: + self.home_robot() + + def undock(self): + blocking_undock(self.robot) + + +class SpotLease: + """ + A class that supports execution with Python's "with" statement for safe return of + the lease and settle-then-estop upon exit. Grants control of the Spot's motors. + """ + + def __init__(self, spot, hijack=False): + self.lease_client = spot.robot.ensure_client( + bosdyn.client.lease.LeaseClient.default_service_name + ) + if hijack: + self.lease = self.lease_client.take() + else: + self.lease = self.lease_client.acquire() + self.lease_keep_alive = bosdyn.client.lease.LeaseKeepAlive(self.lease_client) + self.spot = spot + + def __enter__(self): + return self.lease + + def __exit__(self, exc_type, exc_val, exc_tb): + # Exit the LeaseKeepAlive object + self.lease_keep_alive.__exit__(exc_type, exc_val, exc_tb) + # Return the lease + self.lease_client.return_lease(self.lease) + self.spot.loginfo("Returned the lease.") + # Clear lease from Spot object + self.spot.spot_lease = None + + def create_sublease(self): + return self.lease.create_sublease() + + +def make_robot_command(arm_joint_traj): + """Helper function to create a RobotCommand from an ArmJointTrajectory. + The returned command will be a SynchronizedCommand with an ArmJointMoveCommand + filled out to follow the passed in trajectory.""" + + joint_move_command = arm_command_pb2.ArmJointMoveCommand.Request( + trajectory=arm_joint_traj + ) + arm_command = arm_command_pb2.ArmCommand.Request( + arm_joint_move_command=joint_move_command + ) + sync_arm = synchronized_command_pb2.SynchronizedCommand.Request( + arm_command=arm_command + ) + arm_sync_robot_cmd = robot_command_pb2.RobotCommand(synchronized_command=sync_arm) + return RobotCommandBuilder.build_synchro_command(arm_sync_robot_cmd) + + +def image_response_to_cv2(image_response, reorient=True): + if image_response.shot.image.pixel_format == image_pb2.Image.PIXEL_FORMAT_DEPTH_U16: + dtype = np.uint16 + else: + dtype = np.uint8 + # img = np.fromstring(image_response.shot.image.data, dtype=dtype) + img = np.frombuffer(image_response.shot.image.data, dtype=dtype) + if image_response.shot.image.format == image_pb2.Image.FORMAT_RAW: + img = img.reshape( + image_response.shot.image.rows, image_response.shot.image.cols + ) + else: + img = cv2.imdecode(img, -1) + + if reorient and image_response.source.name in SHOULD_ROTATE: + img = np.rot90(img, k=3) + + return img + + +def scale_depth_img(img, min_depth=0.0, max_depth=10.0, as_img=False): + min_depth, max_depth = min_depth * 1000, max_depth * 1000 + img_copy = np.clip(img.astype(np.float32), a_min=min_depth, a_max=max_depth) + img_copy = (img_copy - min_depth) / (max_depth - min_depth) + if as_img: + img_copy = cv2.cvtColor((255.0 * img_copy).astype(np.uint8), cv2.COLOR_GRAY2BGR) + + return img_copy + + +def draw_crosshair(img): + height, width = img.shape[:2] + cx, cy = width // 2, height // 2 + img = cv2.circle( + img, + center=(cx, cy), + radius=5, + color=(0, 0, 255), + thickness=1, + ) + + return img + + +def wrap_heading(heading): + """Ensures input heading is between -180 an 180; can be float or np.ndarray""" + return (heading + np.pi) % (2 * np.pi) - np.pi diff --git a/bd_spot_wrapper/spot_wrapper/stand.py b/bd_spot_wrapper/spot_wrapper/stand.py new file mode 100644 index 00000000..0a3d061b --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/stand.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + """Make Spot stand""" + spot.power_on() + spot.blocking_stand() + + # Wait 3 seconds to before powering down... + while True: + pass + time.sleep(3) + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("BasicStandingClient") + with spot.get_lease() as lease: + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/utils.py b/bd_spot_wrapper/spot_wrapper/utils.py new file mode 100644 index 00000000..5db4f2fa --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/utils.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import subprocess + +import cv2 +import numpy as np + + +def say(text): + try: + text = text.replace("_", " ") + text = f'"{text}"' + cmd = f"/usr/bin/festival -b '(voice_cmu_us_slt_arctic_hts)' '(SayText {text})'" + subprocess.Popen(cmd, shell=True) + except Exception: + pass + print(f'Saying: "{text}"') + + +def resize_to_tallest(imgs, hstack=False): + tallest = max([i.shape[0] for i in imgs]) + for idx, i in enumerate(imgs): + height, width = i.shape[:2] + if height != tallest: + new_width = int(width * (tallest / height)) + imgs[idx] = cv2.resize(i, (new_width, tallest)) + if hstack: + return np.hstack(imgs) + return imgs + + +def inflate_erode(mask, size=50): + mask_copy = mask.copy() + mask_copy = cv2.blur(mask_copy, (size, size)) + mask_copy[mask_copy > 0] = 255 + mask_copy = cv2.blur(mask_copy, (size, size)) + mask_copy[mask_copy < 255] = 0 + + return mask_copy + + +def erode_inflate(mask, size=20): + mask_copy = mask.copy() + mask_copy = cv2.blur(mask_copy, (size, size)) + mask_copy[mask_copy < 255] = 0 + mask_copy = cv2.blur(mask_copy, (size, size)) + mask_copy[mask_copy > 0] = 255 + + return mask_copy + + +def contour_mask(mask): + cnt, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + new_mask = np.zeros(mask.shape, dtype=np.uint8) + max_area = 0 + max_index = 0 + for idx, c in enumerate(cnt): + area = cv2.contourArea(c) + if area > max_area: + max_area = area + max_index = idx + cv2.drawContours(new_mask, cnt, max_index, 255, cv2.FILLED) + + return new_mask + + +def color_bbox(img, just_get_bbox=False): + """Makes a bbox around a white object""" + # Filter out non-white + sensitivity = 80 + upper_white = np.array([255, 255, 255]) + lower_white = upper_white - sensitivity + color_mask = cv2.inRange(img, lower_white, upper_white) + + # Filter out little bits of white + color_mask = inflate_erode(color_mask) + color_mask = erode_inflate(color_mask) + + # Only use largest contour + color_mask = contour_mask(color_mask) + + # Calculate bbox + x, y, w, h = cv2.boundingRect(color_mask) + + if just_get_bbox: + return x, y, w, h + + height, width = color_mask.shape + cx = (x + w / 2.0) / width + cy = (y + h / 2.0) / height + + # Create bbox mask + bbox_mask = np.zeros([height, width, 1], dtype=np.float32) + bbox_mask[y : y + h, x : x + w] = 1.0 + + # Determine if bbox intersects with central crosshair + crosshair_in_bbox = x < width // 2 < x + w and y < height // 2 < y + h + + return bbox_mask, cx, cy, crosshair_in_bbox diff --git a/bd_spot_wrapper/spot_wrapper/view_arm_proprioception.py b/bd_spot_wrapper/spot_wrapper/view_arm_proprioception.py new file mode 100644 index 00000000..8d210eb3 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/view_arm_proprioception.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +import numpy as np +from spot_wrapper.spot import Spot + + +def main(spot: Spot): + while True: + arm_prop = spot.get_arm_proprioception() + current_joint_positions = np.array( + [v.position.value for v in arm_prop.values()] + ) + spot.loginfo(", ".join([str(i) for i in np.rad2deg(current_joint_positions)])) + spot.loginfo([v.name for v in arm_prop.values()]) + time.sleep(1 / 30) + + +if __name__ == "__main__": + spot = Spot("ArmJointControl") + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/view_camera.py b/bd_spot_wrapper/spot_wrapper/view_camera.py new file mode 100644 index 00000000..dd7170a2 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/view_camera.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import argparse +import time +from collections import deque + +import cv2 +import numpy as np +from spot_wrapper.spot import ( + Spot, + SpotCamIds, + draw_crosshair, + image_response_to_cv2, + scale_depth_img, +) +from spot_wrapper.utils import color_bbox, resize_to_tallest + +MAX_HAND_DEPTH = 3.0 +MAX_HEAD_DEPTH = 10.0 +DETECT_LARGEST_WHITE_OBJECT = False + + +def main(spot: Spot): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--no-display", action="store_true") + parser.add_argument("-q", "--quality", type=int) + args = parser.parse_args() + window_name = "Spot Camera Viewer" + time_buffer = deque(maxlen=10) + sources = [ + SpotCamIds.FRONTRIGHT_DEPTH, + SpotCamIds.FRONTLEFT_DEPTH, + SpotCamIds.HAND_DEPTH, + SpotCamIds.HAND_COLOR, + ] + try: + while True: + start_time = time.time() + + # Get Spot camera image + image_responses = spot.get_image_responses(sources, quality=args.quality) + imgs = [] + for image_response, source in zip(image_responses, sources): + img = image_response_to_cv2(image_response, reorient=True) + if "depth" in source: + max_depth = MAX_HAND_DEPTH if "hand" in source else MAX_HEAD_DEPTH + img = scale_depth_img(img, max_depth=max_depth, as_img=True) + elif source is SpotCamIds.HAND_COLOR: + img = draw_crosshair(img) + if DETECT_LARGEST_WHITE_OBJECT: + x, y, w, h = color_bbox(img, just_get_bbox=True) + cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) + + imgs.append(img) + + # Make sure all imgs are same height + img = resize_to_tallest(imgs, hstack=True) + + if not args.no_display: + cv2.imshow(window_name, img) + cv2.waitKey(1) + + time_buffer.append(time.time() - start_time) + print("Avg FPS:", 1 / np.mean(time_buffer)) + finally: + if not args.no_display: + cv2.destroyWindow(window_name) + + +if __name__ == "__main__": + spot = Spot("ViewCamera") + # We don't need a lease because we're passively observing images (no motor ctrl) + main(spot) diff --git a/bd_spot_wrapper/spot_wrapper/view_camera_and_record.py b/bd_spot_wrapper/spot_wrapper/view_camera_and_record.py new file mode 100644 index 00000000..f8982e32 --- /dev/null +++ b/bd_spot_wrapper/spot_wrapper/view_camera_and_record.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import argparse +import time +from collections import deque + +import cv2 +import numpy as np +from spot_wrapper.spot import ( + Spot, + SpotCamIds, + draw_crosshair, + image_response_to_cv2, + scale_depth_img, +) +from spot_wrapper.utils import color_bbox, resize_to_tallest + +MAX_HAND_DEPTH = 3.0 +MAX_HEAD_DEPTH = 10.0 +DETECT_LARGEST_WHITE_OBJECT = False + +FOUR_CC = cv2.VideoWriter_fourcc(*"MP4V") +FPS = 30 + + +def main(spot: Spot): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--no-display", action="store_true") + parser.add_argument("-q", "--quality", type=int) + args = parser.parse_args() + window_name = "Spot Camera Viewer" + time_buffer = deque(maxlen=10) + sources = [ + # SpotCamIds.FRONTRIGHT_DEPTH, + # SpotCamIds.FRONTLEFT_DEPTH, + # SpotCamIds.HAND_DEPTH, + SpotCamIds.HAND_COLOR, + ] + try: + all_imgs = [] + k = 0 + while True: + start_time = time.time() + + # Get Spot camera image + image_responses = spot.get_image_responses(sources, quality=args.quality) + imgs = [] + for image_response, source in zip(image_responses, sources): + img = image_response_to_cv2(image_response, reorient=True) + if "depth" in source: + max_depth = MAX_HAND_DEPTH if "hand" in source else MAX_HEAD_DEPTH + img = scale_depth_img(img, max_depth=max_depth, as_img=True) + elif source is SpotCamIds.HAND_COLOR: + # img = draw_crosshair(img) + if DETECT_LARGEST_WHITE_OBJECT: + x, y, w, h = color_bbox(img, just_get_bbox=True) + cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) + + imgs.append(img) + all_imgs.append(img) + w, h = img.shape[:2] + + # Make sure all imgs are same height + img = resize_to_tallest(imgs, hstack=True) + + if not args.no_display: + cv2.imshow(window_name, img) + cv2.waitKey(1) + + if k % 10 == 0: + cv2.imwrite(f"videos_april/img_{k}.jpg", img) + + # if k % 50 == 0 and k > 0: + # new_video = cv2.VideoWriter( + # f'videos_april/video_{video_index}.mp4', + # -1, + # FPS, + # (w, h), + # ) + # for img in all_imgs: + # new_video.write(img) + + # new_video.release() + # all_imgs = [] + # break + + k += 1 + time_buffer.append(time.time() - start_time) + print("Avg FPS:", 1 / np.mean(time_buffer)) + finally: + if not args.no_display: + cv2.destroyWindow(window_name) + + +if __name__ == "__main__": + spot = Spot("ViewCamera") + # We don't need a lease because we're passively observing images (no motor ctrl) + main(spot) diff --git a/installation/ISSUES.md b/installation/ISSUES.md new file mode 100644 index 00000000..6edb3ff6 --- /dev/null +++ b/installation/ISSUES.md @@ -0,0 +1,220 @@ +# Common Issues + +The following are some of the most commonly seen issues + +## If you face an issue saying "The detected CUDA version (12.1) mismatches the version that was used to compile" : +```bash +RuntimeError: + The detected CUDA version (12.1) mismatches the version that was used to compile + PyTorch (11.3). Please make sure to use the same CUDA versions. + + [end of output] + +note: This error originates from a subprocess, and is likely not a problem with pip. +``` +Root cause: Your system has an nvidia-driver (with CUDA=12.1 in my case). But we used a different CUDA(=11.3) to compile pytorch. Installation of detectron2 python package does not like this and will complain. + +Tried solution : Delete all nvidia drivers from system (root) and install a new one. It would be better to use an **11.x** driver. This is a solution that we use, but you can use a different method. + +## How to find IP address of local computer +Find local ip of the computer using `ifconfig`. Try to find the profile with flags ``, the *inet* corresponding to that profile is the ip of your computer. + +```bash +(spot_ros) user@linux-machine:~$ ifconfig +enp69s0: flags=4163 mtu 1500 <---------------------------- This is the profile we are looking at + inet 192.168.1.6 netmask 255.255.255.0 broadcast 192.168.1.255 + ... + ... + +lo: flags=73 mtu 65536 + inet 127.0.0.1 netmask 255.0.0.0 + ... + ... + +``` + +## Issues with `spot_rl_launch_local`: + + If you are seeing the following error while running `spot_rl_launch_local` then you are missing `tmux` package. + ```bash + (spot_ros) user@linux-machine:~/spot-sim2real$ spot_rl_launch_local + Killing all tmux sessions... + /path/to/local_only.sh: line 2: tmux: command not found + /path/to/local_only.sh: line 3: tmux: command not found + /path/to/local_only.sh: line 4: tmux: command not found + /path/to/local_only.sh: line 5: tmux: command not found + Starting roscore tmux... + /path/to/local_only.sh: line 8: tmux: command not found + Starting other tmux nodes.. + /path/to/local_only.sh: line 10: tmux: command not found + /path/to/local_only.sh: line 11: tmux: command not found + /path/to/local_only.sh: line 12: tmux: command not found + /path/to/local_only.sh: line 14: tmux: command not found + ``` + Easy fix : + 1. Install tmux using `sudo apt install tmux` + +### Debugging strategies for `spot_rl_launch_local` if any one of the 4 sessions are dying before 70 seconds + 1. `roscore` + 1. If you see that roscore is dying before 70 seconds, it means that the ip from `ROS_IP`/`ROS_HOSTNAME` and/or `ROS_MASTER_URI` is not matching local ip if your computer, in other words the local ip of your computer has changed. + Follow the instructions regarding ip described above to update the local IP. + 2. Try running following command to ensure roscore is up and running. + ```bash + rostopic list + ``` + 2. `img_publishers` This is an important node. If this node dies, there could be several reasons. + 1. Roscore is not running. If roscore dies, then img_publishers die too. Fixing `roscore` will resolve this particular root-cause too. + 2. Computer is not connected to robot. You can clarify if this is the case by tring to ping the robot `ping $SPOT_IP` + 3. Code specific failure. To debug this, you should try running the following command in the terminal to find out the root cause + ```bash + $CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --local + ``` + Once you have fixed the issue, you need to kill all `img_publishers` nodes that are running `spot_rl_launch_local`, this can be done using `htop` + 4. If failure is due to missing `waypoints.yaml` file, then [follow these steps to generate the `waypoints.yaml` file](/README.md#video_game-instructions-to-record-waypoints-use-joystick-to-move-robot-around) + 5. If you face an issue regarding `"Block8 has no module relu"`, [follow these steps described in ISSUES.md](/installation/ISSUES.md#if-you-face-an-issue-saying-block8-has-no-module-relu) + 3. `proprioception` + 1. This node dies sometimes due to roscore taking quite a while to start up. Re-running `spot_rl_launch_local` should fix this in most cases. + 2. If it still does not get fixed, run this command on a new terminal + ```bash + $CONDA_PREFIX/bin/python -m spot_rl.utils.helper_nodes --proprioception + ``` + Once you have fixed the issue, you need to kill all `proprioception` nodes that are running before running `spot_rl_launch_local`, this can be done using `htop` + 3. If failure is due to missing `waypoints.yaml` file, then [follow these steps to generate the `waypoints.yaml` file](/README.md#video_game-instructions-to-record-waypoints-use-joystick-to-move-robot-around) + 4. `tts` If this node dies, we would be surprised too. In that case, try re-running `spot_rl_launch_local`. If it still fails, don't bother fixing this one :p + + +## If you face an issue saying "Block8 has no module relu": + +Issue: +``` +(spot_ros) user@linux-machine:~/spot-sim2real$ $CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --local + +RuntimeError: +Module 'Block8' has no attribute 'relu' : + File "/home/test/miniconda3/envs/spot_ros/lib/python3.8/site-packages/pretrainedmodels/models/inceptionresnetv2.py", line 231 + out = out * self.scale + x + if not self.noReLU: + out = self.relu(out) + ~~~~~~~~~ <--- HERE + return out + +``` +Soln : Then open the file `inceptionresnetv2.py` at `/home/test/miniconda3/envs/spot_ros/lib/python3.8/site-packages/pretrainedmodels/models/inceptionresnetv2.py` in editor of your choice and make it look like the following. + +Before the change, the code should look like follows: +```bash +204 class Block8(nn.Module): +205 +206 def __init__(self, scale=1.0, noReLU=False): +207 super(Block8, self).__init__() +208 +209 self.scale = scale +210 self.noReLU = noReLU +211 +212 self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) +213 +214 self.branch1 = nn.Sequential( +215 BasicConv2d(2080, 192, kernel_size=1, stride=1), +216 BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), +217 BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) +218 ) +219 +220 self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) +221 if not self.noReLU: +222 self.relu = nn.ReLU(inplace=False) +223 +``` + +After the change, the code should look like follows: +```bash +204 class Block8(nn.Module): +205 +206 def __init__(self, scale=1.0, noReLU=False): +207 super(Block8, self).__init__() +208 +209 self.scale = scale +210 self.noReLU = noReLU +211 +212 self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) +213 +214 self.branch1 = nn.Sequential( +215 BasicConv2d(2080, 192, kernel_size=1, stride=1), +216 BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), +217 BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) +218 ) +219 +220 self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) +221 # if not self.noReLU: +222 self.relu = nn.ReLU(inplace=False) +223 +``` + +**Please only change code on line `221` & `222` inside `inceptionresnetv2.py`** + + ## Issues while downloading weights + If you see issues like + ```bash + Archive: spot-sim2real-data/weight/weights.zip End-of-central-directory signature not found. Either this file is not a zipfile, or it constitutes one disk of a multi-part archive. In the latter case the central directory and zipfile comment will be found on the last disk(s) of this archive. unzip: cannot find zipfile directory in one of spot-sim2real-data/weight/weights.zip or spot-sim2real-data/weight/weights.zip.zip, and cannot find spot-sim2real-data/weight/weights.zip.ZIP, period + ``` + + OR + ```bash + Ick! 0x73726576 + ``` + + It means git-lfs has not been installed properly on your system. You can install git-lfs by the following commands + ```bash + curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash + sudo apt-get install git-lfs + git-lfs install + ``` + + ## Issues while running setup.py for Detectron2 + If you face the following issues while running setup.py for Detectron2, it indicates either gcc & g++ are missing from your system or are not linked properly. + + ```bash +error: command 'gcc' failed: No such file or directory + [end of output] + + note: This error originates from a subprocess, and is likely not a problem with pip. + ERROR: Failed building wheel for pycocotools +Failed to build pycocotools +ERROR: Could not build wheels for pycocotools, which is required to install pyproject.toml-based projects +``` + + Easy fix : + 1. Remove gcc & g++ versions if they exist, + 2. Install gcc -> `sudo apt-get install gcc` + 3. Install g++ -> `sudo apt-get install g++` + 4. Then retry running the same setup.py command + + ## Issues while running setup.py for Habitat-lab + If you face the following issues while running setup.py for Habitat-lab, it is because a running the setup script tried to install a newer version of `tensorflow` which depends on newer version of `numpy` which conflicts with already existing `numpy` version in our v-env. + ```bash + Installed /home/user/miniconda3/envs/spot_ros/lib/python3.8/site-packages/tensorflow-2.13.0-py3.8-linux-x86_64.egg + Searching for tensorflow-estimator<2.14,>=2.13.0 + Reading https://pypi.org/simple/tensorflow-estimator/ + Downloading https://files.pythonhosted.org/packages/72/5c/c318268d96791c6222ad7df1651bbd1b2409139afeb6f468c0f327177016/tensorflow_estimator-2.13.0-py2.py3-none-any.whl#sha256=6f868284eaa654ae3aa7cacdbef2175d0909df9fcf11374f5166f8bf475952aa + Best match: tensorflow-estimator 2.13.0 + Processing tensorflow_estimator-2.13.0-py2.py3-none-any.whl + Installing tensorflow_estimator-2.13.0-py2.py3-none-any.whl to /home/user/miniconda3/envs/spot_ros/lib/python3.8/site-packages + Adding tensorflow-estimator 2.13.0 to easy-install.pth file + + Installed /home/user/miniconda3/envs/spot_ros/lib/python3.8/site-packages/tensorflow_estimator-2.13.0-py3.8.egg + Searching for tensorboard<2.14,>=2.13 + Reading https://pypi.org/simple/tensorboard/ + Downloading https://files.pythonhosted.org/packages/67/f2/e8be5599634ff063fa2c59b7b51636815909d5140a26df9f02ce5d99b81a/tensorboard-2.13.0-py3-none-any.whl#sha256=ab69961ebddbddc83f5fa2ff9233572bdad5b883778c35e4fe94bf1798bd8481 + Best match: tensorboard 2.13.0 + Processing tensorboard-2.13.0-py3-none-any.whl + Installing tensorboard-2.13.0-py3-none-any.whl to /home/user/miniconda3/envs/spot_ros/lib/python3.8/site-packages + Adding tensorboard 2.13.0 to easy-install.pth file + Installing tensorboard script to /home/user/miniconda3/envs/spot_ros/bin + + Installed /home/user/miniconda3/envs/spot_ros/lib/python3.8/site-packages/tensorboard-2.13.0-py3.8.egg + error: numpy 1.21.6 is installed but numpy<=1.24.3,>=1.22 is required by {'tensorflow'} + ``` + + Easy fix : + 1. Clear all caches from mamba - `mamba clean -f -a` + 2. Install spefic tensorflow version - `pip install tensorflow==2.9.0` + 3. Then retry running the same setup.py command diff --git a/installation/SETUP_INSTRUCTIONS.md b/installation/SETUP_INSTRUCTIONS.md new file mode 100644 index 00000000..54e01b71 --- /dev/null +++ b/installation/SETUP_INSTRUCTIONS.md @@ -0,0 +1,237 @@ +# Setup Instructions + +### Clone the repo + +```bash +git clone git@github.com:facebookresearch/spot-sim2real.git +cd spot-sim2real/ +git submodule update --init --recursive +``` + +### Update the system packages and install required pkgs + +```bash +sudo apt-get update +sudo apt-get install gcc +sudo apt-get install g++ +sudo apt install tmux +``` + +### Install Miniconda at /home/ + +```bash +# Download miniconda +cd ~/ && wget https://repo.anaconda.com/miniconda/Miniconda3-py39_23.3.1-0-Linux-x86_64.sh + +# Run the script +bash Miniconda3-py39_23.3.1-0-Linux-x86_64.sh +``` + +### Follow these instructions while installing Miniconda + +```bash +Do you accept the license terms? [yes|no] +[no] >>> +Please answer 'yes' or 'no':' -- + +Miniconda3 will now be installed into this location: +/home//miniconda3 -- + +Do you wish the installer to initialize Miniconda3 +by running conda init? [yes|no] +[no] -- /miniconda3/bin:$PATH' >> ~/.bashrc +source ~/.bashrc + +# Check +conda --version +``` + +### Install mamba + +```bash +# Install +conda install -c conda-forge mamba + +# Check +mamba --version +``` + +### Create environment (Takes a while) + +```bash +# cd into the cloned repository +cd ~/spot-sim2real/ + +# Use the yaml file to setup the environemnt +mamba env create -f installation/environment.yml +source ~/.bashrc +mamba init + +# Update bashrc to activate this environment +echo 'mamba activate spot_ros' >> ~/.bashrc +source ~/.bashrc +``` + +### Install torch and cuda packages (**through the installation preview, ensure all of the following packages are installed as CUDA versions and not CPU versions**) + +```bash +mamba install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch + +# If this command fails for error "Could not solve for environment specs", run the following + +# mamba install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -c conda-forge +``` + +### Add required channels + +```bash +conda config --env --add channels conda-forge +conda config --env --add channels robostack-experimental +conda config --env --add channels robostack +conda config --env --set channel_priority strict +``` + +### Setup bd_spot_wrapper + +```bash +# Generate module +cd bd_spot_wrapper/ && python generate_executables.py +pip install -e . && cd ../ +``` + +### Setup spot_rl_experiments + +```bash +# Generate module +cd spot_rl_experiments/ && python generate_executables.py +pip install -e . + +# Get git lfs (large file system) +sudo apt-get install git-lfs +git lfs install + +# Download weights (need only once) +git clone https://huggingface.co/spaces/jimmytyyang/spot-sim2real-data +unzip spot-sim2real-data/weight/weights.zip && rm -rf spot-sim2real-data && cd ../ +``` + +### Setup MaskRCNN + +```bash +# Generate module +cd third_party/mask_rcnn_detectron2/ && pip install -e . + +# Setup detectron +git clone git@github.com:facebookresearch/detectron2.git +pip install -e detectron2 && cd ../../ +``` +If you face any issues in this step, refer to [this section in ISSUES.md](/installation/ISSUES.md#issues-while-running-setuppy-for-detectron2) + +### Setup DeblurGAN + +```bash +# Generate module +cd third_party/DeblurGANv2/ && pip install -e . && cd ../../ +``` + +### Setup Habitat-lab + +```bash +cd third_party/habitat-lab/ +mamba install -c aihabitat habitat-sim==0.2.1 -y +python setup.py develop --all +cd ../../ +``` +If you face any issues in this step, refer to [this section in ISSUES.md](/installation/ISSUES.md#issues-while-running-setuppy-for-habitat-lab) + +### Download inceptionresnet weights + +```bash +# Create dir to store weights if it does not exist +mkdir -p ~/.cache/torch/hub/checkpoints + +# Get weights (May take a while) +wget http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth -O ~/.cache/torch/hub/checkpoints/inceptionresnetv2-520b38e4.pth --no-check-certificate +``` + +### Ensure you have port-audio library for sounddevice (useful for connecting external microphones for speech-to-text) + +```bash +sudo apt-get install libportaudio2 +``` + + +### Setting ROS env variables +* If using **ROS on only 1 computer** (i.e. you don't need 2 or more machines in the ROS network), follow these steps + ```bash + echo 'export ROS_HOSTNAME=localhost' >> ~/.bashrc + echo 'export ROS_MASTER_URI=http://localhost:11311' >> ~/.bashrc + source ~/.bashrc + ``` +* If using **ROS across multiple computers**, follow these steps on each computer + ```bash + # your_local_ip = ip address of this computer in the network + echo 'export ROS_IP=' >> ~/.bashrc + # ros_masters_ip = ip address of the computer running roscore + echo 'export ROS_MASTER_URI=http://:11311' >> ~/.bashrc + source ~/.bashrc + ``` + +For assistance with finding the right ip of your computer, [please follow these steps](/installation/ISSUES.md#how-to-find-ip-address-of-local-computer). + +### Setup SPOT Robot +- Connect to robot's wifi, password for this wifi can be found in robot's belly after removing battery. +- Make sure that the robot is in access point mode (update to client mode in future). Refer to [this](https://support.bostondynamics.com/s/article/Spot-network-setup) page for information regarding Spot's network setup. + +```bash +echo 'export SPOT_ADMIN_PW=' >> ~/.bashrc +echo 'export SPOT_IP=' >> ~/.bashrc +source ~/.bashrc +``` + +### Testing the setup by running simple navigation policy on robot +1. Create waypoints.yaml file using the following command + ```bash + spot_rl_waypoint_recorder -x + ``` +2. Follow Steps 1,2,3,4 from [README.md](/README.md#running-the-demo-asclscseq-experts) +3. Go to root of repo, and run simple command to move robot to a new waypoint using the navigation policy. This command will move robot 2.5m in front after undocking. **Ensure there is 2.5m space in front of dock** + ```bash + python spot_rl_experiments/spot_rl/envs/nav_env.py -w "test_receptacle" + ``` +4. Once the robot has moved, you can dock back the robot with the following command + ```bash + spot_rl_autodock + ``` + +### For Meta internal users (with Meta account), please check the following link for the ip and the password + +[Link](https://docs.google.com/document/d/1u4x4ZMjHDQi33PB5V2aTZ3snUIV9UdkSOV1zFhRq1Do/edit) + +### Mac Users + +It is not recommended to run the code on a Mac machine, and we do not support this. However, it is possible to run the code on a Mac machine. Please reach out to Jimmy Yang (jimmytyyang@meta.com) for help. + + +### For folks who are interested to contribute to this repo, you'll need to setup pre-commit. +The repo runs CI tests on each PR and the PRs are merged only when the all checks have passed. +Installing the pre-commit allows you to run automatic pre-commit while running `git commit`. +```bash +pre-commit install +``` + +### Creating an account on CircleCI +- Since we have integrated CircleCI tests on this repo, you would need to create and link your CircleCI account +- You can create your account from this link (https://app.circleci.com/). Once you have created the account, go to "Organization Settings", on the left tab click on "VCS" +- Finally click on "Manage GitHub Checks". CircleCI will request access to `facebookresearch` org owner. diff --git a/installation/environment.yml b/installation/environment.yml new file mode 100644 index 00000000..8fc9e9ed --- /dev/null +++ b/installation/environment.yml @@ -0,0 +1,594 @@ +name: spot_ros +channels: + - pytorch + - robostack + - robostack-experimental + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=1_llvm + - alsa-lib=1.2.3=h516909a_0 + - apr=1.7.0=h7f98852_5 + - assimp=5.0.1=hedfc422_7 + - atk-1.0=2.36.0=h3371d22_4 + - attrs=21.4.0=pyhd8ed1ab_0 + - bcrypt=3.2.0=py38h497a2fe_2 + - blas=2.114=mkl + - blas-devel=3.9.0=14_linux64_mkl + - boost=1.74.0=py38h2b96118_5 + - boost-cpp=1.74.0=h312852a_4 + - brotli=1.0.9=h166bdaf_7 + - brotli-bin=1.0.9=h166bdaf_7 + - brotlipy=0.7.0=py38h0a891b7_1004 + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.18.1=h7f98852_0 + - cairo=1.16.0=h6cf1ce9_1008 + - catkin_pkg=0.4.24=pyhd8ed1ab_1 + - cffi=1.15.0=py38h3931269_0 + - charset-normalizer=2.0.12=pyhd8ed1ab_0 + - cmake=3.23.1=h5432695_0 + - colorama=0.4.4=pyh9f0ad1d_0 + - console_bridge=1.0.2=h924138e_1 + - cryptography=36.0.2=py38h2b5fc30_1 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - cycler=0.11.0=pyhd8ed1ab_0 + - dbus=1.13.6=h5008d03_3 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - distro=1.6.0=pyhd8ed1ab_0 + - docutils=0.18.1=py38h578d9bd_1 + - eigen=3.4.0=h4bd325d_0 + - empy=3.3.4=pyh9f0ad1d_1 + - expat=2.4.8=h27087fc_0 + - ffmpeg=4.3.2=h37c90e5_3 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=hab24e00_0 + - fontconfig=2.14.0=h8e229c2_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.32.0=py38h0a891b7_0 + - freeimage=3.18.0=hf18588b_8 + - freetype=2.10.4=h0708190_1 + - fribidi=1.0.10=h36c2ea0_0 + - gdk-pixbuf=2.42.8=hff1cb4f_0 + - gettext=0.19.8.1=h73d1719_1008 + - giflib=5.2.1=h36c2ea0_2 + - gitdb=4.0.9=pyhd8ed1ab_0 + - gitpython=3.1.27=pyhd8ed1ab_0 + - gmock=1.11.0=h924138e_0 + - gmp=6.2.1=h58526e2_0 + - gnutls=3.6.13=h85f3911_1 + - gpgme=1.15.1=h9c3ff4c_0 + - graphite2=1.3.13=h58526e2_1001 + - graphviz=2.50.0=h85b4f2f_1 + - gst-plugins-base=1.18.5=hf529b03_3 + - gstreamer=1.18.5=h9f60fe5_3 + - gtest=1.11.0=h924138e_0 + - gtk2=2.24.33=h539f30e_1 + - gts=0.7.6=h64030ff_2 + - harfbuzz=2.9.1=h83ec7ef_1 + - hdf5=1.10.6=nompi_h6a2412b_1114 + - icu=68.2=h9c3ff4c_0 + - idna=3.3=pyhd8ed1ab_0 + - imageio=2.17.0=pyhcf75d05_0 + - imageio-ffmpeg=0.4.7=pyhd8ed1ab_0 + - imath=3.1.5=h6239696_0 + - intel-openmp=2022.0.1=h06a4308_3633 + - jasper=1.900.1=h07fcdf6_1006 + - jbig=2.1=h7f98852_2003 + - jpeg=9e=h166bdaf_1 + - jxrlib=1.1=h7f98852_2 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.2=py38h43d8883_1 + - krb5=1.19.3=h3790be6_0 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=hddcbb42_0 + - ld_impl_linux-64=2.36.1=hea4e1c9_2 + - lerc=3.0=h9c3ff4c_0 + - libapr=1.7.0=h7f98852_5 + - libapriconv=1.2.2=h7f98852_5 + - libaprutil=1.6.1=h975c496_5 + - libassuan=2.5.5=h9c3ff4c_0 + - libblas=3.9.0=14_linux64_mkl + - libbrotlicommon=1.0.9=h166bdaf_7 + - libbrotlidec=1.0.9=h166bdaf_7 + - libbrotlienc=1.0.9=h166bdaf_7 + - libcblas=3.9.0=14_linux64_mkl + - libcurl=7.82.0=h7bff187_0 + - libdeflate=1.10=h7f98852_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libevent=2.1.10=h9b69904_4 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=11.2.0=h1d223b6_15 + - libgd=2.3.3=h6ad9fb6_0 + - libgfortran-ng=11.2.0=h69a702a_15 + - libgfortran5=11.2.0=h5c6108e_15 + - libglib=2.70.2=h174f98d_4 + - libgomp=11.2.0=h1d223b6_15 + - libgpg-error=1.44=h9eb791d_0 + - libiconv=1.16=h516909a_0 + - liblapack=3.9.0=14_linux64_mkl + - liblapacke=3.9.0=14_linux64_mkl + - libllvm10=10.0.1=he513fc3_3 + - libllvm11=11.1.0=hf817b99_3 + - libnghttp2=1.47.0=h727a467_0 + - libnsl=2.0.0=h7f98852_0 + - libogg=1.3.4=h7f98852_1 + - libopencv=4.5.1=py38h703c3c0_0 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.37=h21135ba_2 + - libpq=13.5=hd57d9b9_1 + - libprotobuf=3.19.4=h780b84a_0 + - libraw=0.20.2=h10796ff_1 + - librsvg=2.52.5=hc3c00ef_1 + - libsodium=1.0.18=h36c2ea0_1 + - libssh2=1.10.0=ha56f1ee_2 + - libstdcxx-ng=11.2.0=he4da1e4_15 + - libtiff=4.3.0=h542a066_3 + - libtool=2.4.6=h9c3ff4c_1008 + - libuuid=2.32.1=h7f98852_1000 + - libuv=1.43.0=h7f98852_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libwebp=1.2.2=h3452ae3_0 + - libwebp-base=1.2.2=h7f98852_1 + - libxcb=1.13=h7f98852_1004 + - libxkbcommon=1.0.3=he3ba5ed_0 + - libxml2=2.9.12=h72842e0_0 + - libzlib=1.2.11=h166bdaf_1014 + - llvm-openmp=13.0.1=he0ac6c6_1 + - log4cxx=0.11.0=h291d653_3 + - lz4=4.0.0=py38h1bf946c_1 + - lz4-c=1.9.3=h9c3ff4c_1 + - matplotlib=3.5.1=py38h578d9bd_0 + - matplotlib-base=3.5.1=py38hf4fb855_0 + - mkl=2022.0.1=h8d4b97c_803 + - mkl-devel=2022.0.1=ha770c72_804 + - mkl-include=2022.0.1=h8d4b97c_803 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.28=haf5c9bc_4 + - mysql-libs=8.0.28=h28c427c_4 + - ncurses=6.3=h27087fc_1 + - netifaces=0.10.9=py38h497a2fe_1005 + - nettle=3.6=he412f7d_0 + - ninja=1.10.2=h4bd325d_1 + - nose=1.3.7=py_1006 + - nspr=4.32=h9c3ff4c_1 + - nss=3.77=h2350873_0 + - ogre=1.10.12=h7cc4a1d_8 + - openexr=3.1.5=he0ac6c6_0 + - openh264=2.1.1=h780b84a_0 + - openjpeg=2.4.0=hb52868f_1 + - orocos-kdl=1.4.0=h9c3ff4c_0 + - packaging=21.3=pyhd8ed1ab_0 + - pango=1.48.10=hb8ff022_1 + - paramiko=2.10.3=pyhd8ed1ab_0 + - pcre=8.45=h9c3ff4c_0 + - pip=22.0.4=pyhd8ed1ab_0 + - pixman=0.40.0=h36c2ea0_0 + - pkg-config=0.29.2=h36c2ea0_1008 + - poco=1.10.1=h9c89518_1 + - psutil=5.9.0=py38h0a891b7_1 + - pthread-stubs=0.4=h36c2ea0_1001 + - pugixml=1.11.4=h9c3ff4c_0 + - py-opencv=4.5.1=py38h81c977d_0 + - pycairo=1.21.0=py38h9c00e7a_1 + - pycparser=2.21=pyhd8ed1ab_0 + - pycryptodome=3.14.1=py38h0757ffc_1 + - pycryptodomex=3.14.1=py38h0a891b7_1 + - pydot=1.4.2=py38h578d9bd_1 + - pynacl=1.5.0=py38h0a891b7_1 + - pyopengl=3.1.6=pyh6c4a22f_0 + - pyopenssl=22.0.0=pyhd8ed1ab_0 + - pyparsing=3.0.8=pyhd8ed1ab_0 + - pyqt=5.12.3=py38h578d9bd_8 + - pyqt-impl=5.12.3=py38h0ffb2e6_8 + - pyqt5-sip=4.19.18=py38h709712a_8 + - pyqtchart=5.12=py38h7400c14_8 + - pyqtwebengine=5.12.1=py38h7400c14_8 + - pysocks=1.7.1=py38h578d9bd_5 + - python=3.8.13=h582c2e5_0_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python-gnupg=0.4.8=pyhd8ed1ab_0 + - python-orocos-kdl=1.4.0=py38h709712a_0 + - python_abi=3.8=2_cp38 + - pyyaml=6.0=py38h0a891b7_4 + - qt=5.12.9=hda022c4_4 + - quaternion=2022.4.1=py38h71d37f0_3 + - readline=8.1=h46c0cb4_0 + - requests=2.27.1=pyhd8ed1ab_0 + - rhash=1.4.1=h7f98852_0 + - ros-distro-mutex=0.1.0=noetic + - ros-noetic-actionlib=1.13.2=py38h14b2acc_6 + - ros-noetic-actionlib-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-actionlib-tutorials=0.2.0=py38he9ab703_6 + - ros-noetic-angles=1.9.13=py38he9ab703_6 + - ros-noetic-bond=1.8.6=py38he9ab703_6 + - ros-noetic-bond-core=1.8.6=py38he9ab703_6 + - ros-noetic-bondcpp=1.8.6=py38hbce7a70_6 + - ros-noetic-bondpy=1.8.6=py38he9ab703_6 + - ros-noetic-catkin=0.8.10=py38he9ab703_9 + - ros-noetic-class-loader=0.5.0=py38h6655857_6 + - ros-noetic-cmake-modules=0.5.0=py38he9ab703_6 + - ros-noetic-common-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-common-tutorials=0.2.0=py38he9ab703_6 + - ros-noetic-control-msgs=1.5.2=py38he9ab703_6 + - ros-noetic-cpp-common=0.7.2=py38haa43186_6 + - ros-noetic-cv-bridge=1.15.0=py38heb39ad0_6 + - ros-noetic-desktop=1.5.0=py38he9ab703_6 + - ros-noetic-diagnostic-aggregator=1.10.3=py38he9ab703_6 + - ros-noetic-diagnostic-analysis=1.10.3=py38he9ab703_6 + - ros-noetic-diagnostic-common-diagnostics=1.10.3=py38he9ab703_6 + - ros-noetic-diagnostic-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-diagnostic-updater=1.10.3=py38he9ab703_6 + - ros-noetic-diagnostics=1.10.3=py38he9ab703_6 + - ros-noetic-dynamic-reconfigure=1.7.1=py38h14b2acc_6 + - ros-noetic-eigen-conversions=1.13.2=py38he9ab703_6 + - ros-noetic-executive-smach=2.5.0=py38he9ab703_6 + - ros-noetic-filters=1.9.1=py38h14b2acc_6 + - ros-noetic-gencpp=0.6.5=py38he9ab703_6 + - ros-noetic-geneus=3.0.0=py38he9ab703_6 + - ros-noetic-genlisp=0.4.18=py38he9ab703_6 + - ros-noetic-genmsg=0.5.16=py38he9ab703_6 + - ros-noetic-gennodejs=2.0.2=py38he9ab703_6 + - ros-noetic-genpy=0.6.14=py38he9ab703_6 + - ros-noetic-geometry=1.13.2=py38he9ab703_6 + - ros-noetic-geometry-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-geometry-tutorials=0.2.3=py38he9ab703_6 + - ros-noetic-gl-dependency=1.1.2=py38he9ab703_6 + - ros-noetic-image-transport=1.12.0=py38he9ab703_6 + - ros-noetic-interactive-marker-tutorials=0.11.0=py38he9ab703_6 + - ros-noetic-interactive-markers=1.12.0=py38he9ab703_6 + - ros-noetic-joint-state-publisher=1.15.0=py38he9ab703_6 + - ros-noetic-joint-state-publisher-gui=1.15.0=py38he9ab703_6 + - ros-noetic-kdl-conversions=1.13.2=py38he9ab703_6 + - ros-noetic-kdl-parser=1.14.1=py38he9ab703_6 + - ros-noetic-laser-geometry=1.6.7=py38h266acf0_6 + - ros-noetic-librviz-tutorial=0.11.0=py38h7d895da_6 + - ros-noetic-map-msgs=1.14.1=py38he9ab703_6 + - ros-noetic-media-export=0.3.0=py38he9ab703_6 + - ros-noetic-message-filters=1.15.9=py38h14b2acc_6 + - ros-noetic-message-generation=0.4.1=py38he9ab703_6 + - ros-noetic-message-runtime=0.4.13=py38he9ab703_6 + - ros-noetic-mk=1.15.7=py38he9ab703_6 + - ros-noetic-nav-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-nodelet=1.10.1=py38h343f6d5_6 + - ros-noetic-nodelet-core=1.10.1=py38he9ab703_6 + - ros-noetic-nodelet-topic-tools=1.10.1=py38h266acf0_6 + - ros-noetic-nodelet-tutorial-math=0.2.0=py38he9ab703_6 + - ros-noetic-pluginlib=1.13.0=py38h14b2acc_6 + - ros-noetic-pluginlib-tutorials=0.2.0=py38he9ab703_6 + - ros-noetic-python-qt-binding=0.4.3=py38h7d895da_6 + - ros-noetic-qt-dotgraph=0.4.2=py38he9ab703_6 + - ros-noetic-qt-gui=0.4.2=py38h7d895da_6 + - ros-noetic-qt-gui-cpp=0.4.2=py38h7d895da_6 + - ros-noetic-qt-gui-py-common=0.4.2=py38he9ab703_6 + - ros-noetic-qwt-dependency=1.1.1=py38he9ab703_6 + - ros-noetic-resource-retriever=1.12.6=py38h91cbfbb_6 + - ros-noetic-robot=1.5.0=py38he9ab703_6 + - ros-noetic-robot-state-publisher=1.15.0=py38he9ab703_6 + - ros-noetic-ros=1.15.7=py38he9ab703_6 + - ros-noetic-ros-base=1.5.0=py38he9ab703_6 + - ros-noetic-ros-comm=1.15.9=py38he9ab703_6 + - ros-noetic-ros-core=1.5.0=py38he9ab703_6 + - ros-noetic-ros-environment=1.3.2=py38he9ab703_6 + - ros-noetic-ros-tutorials=0.10.2=py38he9ab703_6 + - ros-noetic-rosbag=1.15.9=py38h14b2acc_6 + - ros-noetic-rosbag-migration-rule=1.0.1=py38he9ab703_6 + - ros-noetic-rosbag-storage=1.15.9=py38hda0ab37_6 + - ros-noetic-rosbash=1.15.7=py38he9ab703_6 + - ros-noetic-rosboost-cfg=1.15.7=py38he9ab703_6 + - ros-noetic-rosbuild=1.15.7=py38he9ab703_6 + - ros-noetic-rosclean=1.15.7=py38he9ab703_6 + - ros-noetic-rosconsole=1.14.3=py38hb51d2b2_6 + - ros-noetic-rosconsole-bridge=0.5.4=py38hcda8483_6 + - ros-noetic-roscpp=1.15.9=py38h14b2acc_6 + - ros-noetic-roscpp-core=0.7.2=py38he9ab703_6 + - ros-noetic-roscpp-serialization=0.7.2=py38he9ab703_6 + - ros-noetic-roscpp-traits=0.7.2=py38he9ab703_6 + - ros-noetic-roscpp-tutorials=0.10.2=py38h14b2acc_6 + - ros-noetic-roscreate=1.15.7=py38he9ab703_6 + - ros-noetic-rosgraph=1.15.9=py38he9ab703_6 + - ros-noetic-rosgraph-msgs=1.11.3=py38he9ab703_6 + - ros-noetic-roslang=1.15.7=py38he9ab703_6 + - ros-noetic-roslaunch=1.15.9=py38he9ab703_6 + - ros-noetic-roslib=1.15.7=py38h14b2acc_6 + - ros-noetic-roslint=0.12.0=py38he9ab703_6 + - ros-noetic-roslisp=1.9.24=py38he9ab703_6 + - ros-noetic-roslz4=1.15.9=py38he9ab703_6 + - ros-noetic-rosmake=1.15.7=py38he9ab703_6 + - ros-noetic-rosmaster=1.15.9=py38he9ab703_6 + - ros-noetic-rosmsg=1.15.9=py38he9ab703_6 + - ros-noetic-rosnode=1.15.9=py38he9ab703_6 + - ros-noetic-rosout=1.15.9=py38he9ab703_6 + - ros-noetic-rospack=2.6.2=py38h14b2acc_6 + - ros-noetic-rosparam=1.15.9=py38he9ab703_6 + - ros-noetic-rospy=1.15.9=py38he9ab703_6 + - ros-noetic-rospy-tutorials=0.10.2=py38he9ab703_6 + - ros-noetic-rosservice=1.15.9=py38he9ab703_6 + - ros-noetic-rostest=1.15.9=py38h14b2acc_6 + - ros-noetic-rostime=0.7.2=py38h14b2acc_6 + - ros-noetic-rostopic=1.15.9=py38he9ab703_6 + - ros-noetic-rosunit=1.15.7=py38he9ab703_6 + - ros-noetic-roswtf=1.15.9=py38he9ab703_6 + - ros-noetic-rqt-action=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-bag=0.5.1=py38he9ab703_6 + - ros-noetic-rqt-bag-plugins=0.5.1=py38he9ab703_6 + - ros-noetic-rqt-common-plugins=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-console=0.4.11=py38he9ab703_6 + - ros-noetic-rqt-dep=0.4.10=py38he9ab703_6 + - ros-noetic-rqt-graph=0.4.14=py38he9ab703_6 + - ros-noetic-rqt-gui=0.5.2=py38he9ab703_6 + - ros-noetic-rqt-gui-cpp=0.5.2=py38h7d895da_6 + - ros-noetic-rqt-gui-py=0.5.2=py38he9ab703_6 + - ros-noetic-rqt-image-view=0.4.16=py38h7d895da_11 + - ros-noetic-rqt-launch=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-logger-level=0.4.11=py38he9ab703_6 + - ros-noetic-rqt-moveit=0.5.9=py38he9ab703_6 + - ros-noetic-rqt-msg=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-nav-view=0.5.7=py38he9ab703_6 + - ros-noetic-rqt-plot=0.4.13=py38he9ab703_6 + - ros-noetic-rqt-pose-view=0.5.10=py38he9ab703_6 + - ros-noetic-rqt-publisher=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-py-common=0.5.2=py38he9ab703_6 + - ros-noetic-rqt-py-console=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-reconfigure=0.5.3=py38he9ab703_6 + - ros-noetic-rqt-robot-dashboard=0.5.8=py38he9ab703_6 + - ros-noetic-rqt-robot-monitor=0.5.13=py38he9ab703_6 + - ros-noetic-rqt-robot-plugins=0.5.8=py38he9ab703_6 + - ros-noetic-rqt-robot-steering=0.5.12=py38he9ab703_6 + - ros-noetic-rqt-runtime-monitor=0.5.8=py38he9ab703_6 + - ros-noetic-rqt-rviz=0.6.1=py38h4fa06b8_6 + - ros-noetic-rqt-service-caller=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-shell=0.4.10=py38he9ab703_6 + - ros-noetic-rqt-srv=0.4.8=py38he9ab703_6 + - ros-noetic-rqt-tf-tree=0.6.2=py38he9ab703_6 + - ros-noetic-rqt-top=0.4.9=py38he9ab703_6 + - ros-noetic-rqt-topic=0.4.12=py38he9ab703_6 + - ros-noetic-rqt-web=0.4.9=py38he9ab703_6 + - ros-noetic-rviz=1.14.5=py38h7d895da_6 + - ros-noetic-rviz-plugin-tutorials=0.11.0=py38h7d895da_6 + - ros-noetic-rviz-python-tutorial=0.11.0=py38he9ab703_6 + - ros-noetic-self-test=1.10.3=py38he9ab703_6 + - ros-noetic-sensor-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-shape-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-smach=2.5.0=py38he9ab703_6 + - ros-noetic-smach-msgs=2.5.0=py38he9ab703_6 + - ros-noetic-smach-ros=2.5.0=py38he9ab703_6 + - ros-noetic-smclib=1.8.6=py38he9ab703_6 + - ros-noetic-std-msgs=0.5.13=py38he9ab703_6 + - ros-noetic-std-srvs=1.11.3=py38he9ab703_6 + - ros-noetic-stereo-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-tf=1.13.2=py38h5038f88_6 + - ros-noetic-tf-conversions=1.13.2=py38he9ab703_6 + - ros-noetic-tf2=0.7.5=py38hcda8483_6 + - ros-noetic-tf2-geometry-msgs=0.7.5=py38he9ab703_6 + - ros-noetic-tf2-kdl=0.7.5=py38he9ab703_6 + - ros-noetic-tf2-msgs=0.7.5=py38he9ab703_6 + - ros-noetic-tf2-py=0.7.5=py38he9ab703_6 + - ros-noetic-tf2-ros=0.7.5=py38he9ab703_6 + - ros-noetic-topic-tools=1.15.9=py38he9ab703_6 + - ros-noetic-trajectory-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-turtle-actionlib=0.2.0=py38he9ab703_6 + - ros-noetic-turtle-tf=0.2.3=py38he9ab703_6 + - ros-noetic-turtle-tf2=0.2.3=py38he9ab703_6 + - ros-noetic-turtlesim=0.10.2=py38h5258f4b_6 + - ros-noetic-urdf=1.13.2=py38he9ab703_6 + - ros-noetic-urdf-parser-plugin=1.13.2=py38he9ab703_6 + - ros-noetic-urdf-tutorial=0.5.0=py38he9ab703_6 + - ros-noetic-visualization-marker-tutorials=0.11.0=py38he9ab703_6 + - ros-noetic-visualization-msgs=1.13.1=py38he9ab703_6 + - ros-noetic-visualization-tutorials=0.11.0=py38he9ab703_6 + - ros-noetic-viz=1.5.0=py38he9ab703_6 + - ros-noetic-webkit-dependency=1.1.2=py38he9ab703_6 + - ros-noetic-xacro=1.14.6=py38he9ab703_6 + - ros-noetic-xmlrpcpp=1.15.9=py38h14b2acc_6 + - rosdep=0.21.0=pyhd8ed1ab_1 + - rosdistro=0.8.3=py38h578d9bd_3 + - rospkg=1.4.0=pyhd8ed1ab_0 + - sbcl=1.5.4=ha770c72_1 + - scipy=1.8.0=py38h56a6a73_1 + - sdl2=2.0.18=h27087fc_0 + - setuptools=62.1.0=py38h578d9bd_0 + - sip=6.5.1=py38h709712a_2 + - six=1.16.0=pyh6c4a22f_0 + - sleef=3.5.1=h9b69904_2 + - smmap=3.0.5=pyh44b312d_0 + - sqlite=3.38.2=h4ff8645_0 + - swig=4.0.2=hd3c618e_2 + - tbb=2021.5.0=h924138e_1 + - tinyxml=2.6.2=h4bd325d_2 + - tinyxml2=8.0.0=h9c3ff4c_1 + - tk=8.6.12=h27826a3_0 + - toml=0.10.2=pyhd8ed1ab_0 + - tornado=6.1=py38h0a891b7_3 + - tqdm=4.64.0=pyhd8ed1ab_0 + - typing_extensions=4.2.0=pyha770c72_1 + - unicodedata2=14.0.0=py38h0a891b7_1 + - unixodbc=2.3.9=hb166930_0 + - urdfdom=3.0.2=h27087fc_2 + - urdfdom_headers=1.0.6=h924138e_2 + - urllib3=1.26.9=pyhd8ed1ab_0 + - wheel=0.37.1=pyhd8ed1ab_0 + - x264=1!161.3030=h7f98852_1 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-inputproto=2.3.2=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.0.10=h7f98852_0 + - xorg-libsm=1.2.3=hd9c2040_1000 + - xorg-libx11=1.7.2=h7f98852_0 + - xorg-libxau=1.0.9=h7f98852_0 + - xorg-libxaw=1.0.14=h7f98852_1 + - xorg-libxcursor=1.2.0=h7f98852_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h7f98852_1 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxi=1.7.10=h7f98852_0 + - xorg-libxinerama=1.1.4=h9c3ff4c_1001 + - xorg-libxmu=1.1.3=h7f98852_0 + - xorg-libxpm=3.5.13=h7f98852_0 + - xorg-libxrandr=1.5.2=h7f98852_1 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-libxt=1.2.1=h7f98852_2 + - xorg-randrproto=1.5.0=h7f98852_1001 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h7f98852_1002 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.5=h516909a_1 + - yaml=0.2.5=h7f98852_2 + - yaml-cpp=0.6.3=he1b5a44_4 + - zlib=1.2.11=h166bdaf_1014 + - zstd=1.5.2=ha95c52a_0 + - zziplib=0.13.69=h27826a3_1 + - pip: + - absl-py==1.0.0 + - aiohttp==3.8.1 + - aiosignal==1.2.0 + - albumentations==1.1.0 + - antlr4-python3-runtime==4.8 + - appdirs==1.4.4 + - args==0.1.0 + - asttokens==2.0.5 + - astunparse==1.6.3 + - async-timeout==4.0.2 + - backcall==0.2.0 + - beautifulsoup4==4.11.1 + - black==21.4b2 + - blosc==1.10.6 + - bosdyn-api==3.2.3 + - bosdyn-choreography-client==3.2.3 + - bosdyn-choreography-protos==3.2.3 + - bosdyn-client==3.2.3 + - bosdyn-core==3.2.3 + - bosdyn-mission==3.2.3 + - braceexpand==0.1.7 + - cachetools==5.0.0 + - click==8.1.2 + - clint==0.5.1 + - cloudpickle==2.0.0 + - coverage==6.3.2 + - datasets==2.10.1 + - decorator==4.4.2 + - deprecated==1.2.13 + - dill==0.3.6 + - executing==0.8.3 + - fastdtw==0.3.4 + - filelock==3.6.0 + - fire==0.4.0 + - flatbuffers==1.12 + - frozenlist==1.3.0 + - fsspec==2023.3.0 + - future==0.18.2 + - fvcore==0.1.5.post20220414 + - gast==0.4.0 + - gdown==4.4.0 + - glog==0.3.1 + - google-auth==1.6.3 + - google-auth-oauthlib==0.4.6 + - google-pasta==0.2.0 + - grpcio==1.44.0 + - gym==0.23.1 + - gym-notices==0.0.6 + - h5py==3.6.0 + - huggingface-hub==0.13.2 + - hydra-core==1.1.2 + - ifcfg==0.22 + - importlib-metadata==4.11.3 + - importlib-resources==5.2.3 + - install==1.3.5 + - iopath==0.1.9 + - ipdb==0.13.11 + - ipython==8.2.0 + - jedi==0.18.1 + - joblib==1.1.0 + - keras==2.9.0rc1 + - keras-preprocessing==1.1.2 + - libclang==13.0.0 + - llvmlite==0.38.0 + - lmdb==1.3.0 + - markdown==3.3.6 + - matplotlib-inline==0.1.3 + - moviepy==2.0.0.dev2 + - msgpack==1.0.3 + - multidict==6.0.2 + - multiprocess==0.70.14 + - munch==2.5.0 + - mypy-extensions==0.4.3 + - networkx==2.8 + - numba==0.55.1 + - numpy==1.21.6 + - oauthlib==3.2.0 + - objectio==0.2.29 + - omegaconf==2.1.2 + - openai==0.27.4 + - opencv-python==4.5.5.64 + - opt-einsum==3.3.0 + - pandas==1.5.3 + - parso==0.8.3 + - pathspec==0.9.0 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==6.2.2 + - portalocker==2.4.0 + - pre-commit==3.1.1 + - pretrainedmodels==0.7.4 + - prompt-toolkit==3.0.29 + - protobuf==3.20.3 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyarrow==11.0.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pybullet==3.2.2 + - pygments==2.11.2 + - pyjwt==2.3.0 + - python-gflags==3.1.2 + - pytz==2022.7.1 + - pywavelets==1.3.0 + - qudida==0.0.4 + - regex==2022.3.15 + - requests-oauthlib==1.3.1 + - responses==0.18.0 + - rsa==4.8 + - ruamel-yaml==0.17.21 + - ruamel-yaml-clib==0.2.6 + - scikit-image==0.18.1 + - scikit-learn==1.0.2 + - simplejson==3.17.6 + - sounddevice==0.4.6 + - soundfile==0.12.1 + - soupsieve==2.3.2.post1 + - stack-data==0.2.0 + - tabulate==0.8.9 + - tb-nightly==2.9.0a20220420 + - tensorboard==2.8.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - tensorboardx==2.5 + - tensorflow-estimator==2.9.0rc0 + - tensorflow-io-gcs-filesystem==0.24.0 + - termcolor==1.1.0 + - threadpoolctl==3.1.0 + - tifffile==2022.4.8 + - tokenizers==0.13.2 + - tomli==2.0.1 + - torchsummary==1.5.1 + - traitlets==5.1.1 + - transformers==4.26.1 + - typer==0.4.1 + - typing-extensions==3.7.4.3 + - vtk==9.1.0 + - wcwidth==0.2.5 + - webdataset==0.1.40 + - webrtcvad-wheels==2.0.11.post1 + - werkzeug==2.1.1 + - whisper==1.1.10 + - wrapt==1.14.0 + - wslink==1.6.6 + - xxhash==3.2.0 + - yacs==0.1.8 + - yarl==1.8.0 + - zipp==3.8.0 +prefix: ~/miniconda3/envs/spot_ros diff --git a/spot_rl_experiments/.gitignore b/spot_rl_experiments/.gitignore new file mode 100644 index 00000000..6633eccb --- /dev/null +++ b/spot_rl_experiments/.gitignore @@ -0,0 +1,7 @@ +*.pth +__pycache__ +bd_spot_wrapper +*.DS_Store +weights +notes +configs/waypoints.yaml diff --git a/spot_rl_experiments/README.md b/spot_rl_experiments/README.md new file mode 100644 index 00000000..f945941d --- /dev/null +++ b/spot_rl_experiments/README.md @@ -0,0 +1,13 @@ +# spot_rl_experiments + +## Installation + +Install requirements +```bash +pip install -r requirements.txt +``` +Install this package +```bash +# Make sure you are in the root of this repo +pip install -e . +``` diff --git a/spot_rl_experiments/__init__.py b/spot_rl_experiments/__init__.py new file mode 100644 index 00000000..5ce8caf5 --- /dev/null +++ b/spot_rl_experiments/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/spot_rl_experiments/configs/config.yaml b/spot_rl_experiments/configs/config.yaml new file mode 100644 index 00000000..bec53173 --- /dev/null +++ b/spot_rl_experiments/configs/config.yaml @@ -0,0 +1,70 @@ +WEIGHTS: + NAV: "weights/CUTOUT_WT_True_SD_200_ckpt.99.pvp.pth" + # GAZE: "weights/gaze_normal_32_seed300_1649708927_ckpt.49.pth" + # PLACE: "weights/speed_sweep_seed2_speed0.174533_1648669786.ckpt.94.pth" + # MIXER: "weights/sweep_17_SD_200_1651156250_ckpt.12.pth" +# new + # NAV: "weights/final_paper/nav_CUTOUT_WT_True_SD_300_ckpt.90.pth" + GAZE: "weights/final_paper/gaze_normal_32_seed100_1649708902_ckpt.38.pth" + PLACE: "weights/final_paper/place_10deg_32_seed300_1649709235_ckpt.75.pth" + MIXER: "weights/final_paper/final_moe_rnn_60_1.0_SD_100_1652120928_ckpt.16_copy.pth" + # MIXER: "weights/final_gater_rnn_60_1.0_SD_300_1652203350_ckpt_55.pth" + MRCNN: "weights/ikea_apricot_large_only_model_0002999.pth" + MRCNN_50: "weights/ikea_apricot_r50_normal_100_output_model_0003599.pth" + DEBLURGAN: "weights/fpn_inception.h5" + +DEVICE: "cuda:0" +USE_REMOTE_SPOT: False +PARALLEL_INFERENCE_MODE: True + +# General env params +CTRL_HZ: 2.0 +MAX_EPISODE_STEPS: 500 + +# Nav env +SUCCESS_DISTANCE: 0.3 +SUCCESS_ANGLE_DIST: 5 +DISABLE_OBSTACLE_AVOIDANCE: True +USE_OA_FOR_NAV: True +USE_HEAD_CAMERA: True + +# Gaze env +CENTER_TOLERANCE: 0.3 +OBJECT_LOCK_ON_NEEDED: 3 +TARGET_OBJ_NAME: "rubiks_cube" +DONT_PICK_UP: False +ASSERT_CENTERING: True + +# Place env +EE_GRIPPER_OFFSET: [0.2, 0.0, 0.05] +SUCC_XY_DIST: 0.1 +SUCC_Z_DIST: 0.20 + +# Base action params +MAX_LIN_DIST: 0.25 # meters +MAX_ANG_DIST: 15.0 # degrees + +# Arm action params +MAX_JOINT_MOVEMENT: 0.0872665 # Gaze arm speed (5 deg) +MAX_JOINT_MOVEMENT_2: 0.174533 # Place arm speed (6 deg) +INITIAL_ARM_JOINT_ANGLES: [0, -170, 120, 0, 75, 0] +GAZE_ARM_JOINT_ANGLES: [0, -160, 100, 0, 90, 0] +PLACE_ARM_JOINT_ANGLES: [0, -170, 120, 0, 75, 0] +ARM_LOWER_LIMITS: [-45, -180, 0, 0, -90, 0] +ARM_UPPER_LIMITS: [45, -45, 180, 0, 90, 0] +JOINT_BLACKLIST: [3, 5] # joints we can't control "arm0.el0", "arm0.wr1" +ACTUALLY_MOVE_ARM: True +GRASP_EVERY_STEP: False +TERMINATE_ON_GRASP: False + +# Mask RCNN +GRAYSCALE_MASK_RCNN: False +USE_MRCNN: True +USE_FPN_R50: False +USE_DEBLURGAN: True +IMAGE_SCALE: 0.7 +# After this many time steps of not seeing the current target object, we become open to looking for new ones +FORGET_TARGET_OBJECT_STEPS: 15 + +# Docking (currently only used by ASC, Seq Exp and Language env) +RETURN_TO_BASE: True diff --git a/spot_rl_experiments/configs/ros_topic_names.yaml b/spot_rl_experiments/configs/ros_topic_names.yaml new file mode 100644 index 00000000..1354290c --- /dev/null +++ b/spot_rl_experiments/configs/ros_topic_names.yaml @@ -0,0 +1,22 @@ +HEAD_DEPTH: "/raw_head_depth" +HAND_DEPTH: "/raw_hand_depth" +HAND_RGB: "/hand_rgb" + +COMPRESSED_IMAGES: "/compressed_images" + +FILTERED_HEAD_DEPTH: "/filtered_head_depth" +FILTERED_HAND_DEPTH: "/filtered_hand_depth" + +MASK_RCNN_VIZ_TOPIC: "/mask_rcnn_visualizations" +DETECTIONS_TOPIC: "/mask_rcnn_detections" +IMAGE_SCALE: "/image_scale" + +ROBOT_STATE: "/robot_state" +TEXT_TO_SPEECH: "/text_to_speech" + + +# Remote robot topics +ROBOT_CMD_TOPIC: "/remote_robot_cmd" +CMD_ENDED_TOPIC: "/remote_robot_cmd_ended" +INIT_REMOTE_ROBOT: "/init_remote_robot" +KILL_REMOTE_ROBOT: "/kill_remote_robot" diff --git a/spot_rl_experiments/experiments/comparisons/gaze_all_objects.py b/spot_rl_experiments/experiments/comparisons/gaze_all_objects.py new file mode 100644 index 00000000..7300c6bd --- /dev/null +++ b/spot_rl_experiments/experiments/comparisons/gaze_all_objects.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import time + +from spot_rl.envs.gaze_env import run_env +from spot_rl.utils.utils import construct_config +from spot_wrapper.spot import Spot +from spot_wrapper.utils import say + +names = ["ball", "penguin", "rubiks_cube", "lion", "toy_car", "yellow_truck"] + + +def main(spot, bd=False): + config = construct_config() + config.DONT_PICK_UP = True + config.OBJECT_LOCK_ON_NEEDED = 5 + # config.CTRL_HZ = 2.0 + config.TERMINATE_ON_GRASP = True + config.FORGET_TARGET_OBJECT_STEPS = 1000000 + if bd: + config.GRASP_EVERY_STEP = True + config.MAX_JOINT_MOVEMENT = 0.0 # freeze arm + config.MAX_EPISODE_STEPS = 20 + else: + config.MAX_EPISODE_STEPS = 150 + orig_pos = None + for _ in range(3): + for name in names: + say("Targeting " + name) + time.sleep(2) + orig_pos = run_env(spot, config, target_obj_id=name, orig_pos=orig_pos) + say("Episode over") + time.sleep(2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--bd", action="store_true") + args = parser.parse_args() + spot = Spot("RealGazeEnv") + with spot.get_lease(hijack=True): + main(spot, bd=args.bd) diff --git a/spot_rl_experiments/experiments/comparisons/multiple_gaze.py b/spot_rl_experiments/experiments/comparisons/multiple_gaze.py new file mode 100644 index 00000000..d8ec03f0 --- /dev/null +++ b/spot_rl_experiments/experiments/comparisons/multiple_gaze.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +from spot_rl.envs.gaze_env import SpotGazeEnv +from spot_rl.real_policy import GazePolicy +from spot_rl.utils.utils import ( + construct_config, + get_default_parser, + nav_target_from_waypoints, + object_id_to_object_name, +) +from spot_wrapper.spot import Spot + + +def main(spot): + parser = get_default_parser() + args = parser.parse_args() + config = construct_config(args.opts) + + env = SpotGazeEnv(config, spot, mask_rcnn_weights=config.WEIGHTS.MRCNN) + env.power_robot() + policy = GazePolicy(config.WEIGHTS.GAZE, device=config.DEVICE) + for target_id in range(1, 9): + goal_x, goal_y, goal_heading = nav_target_from_waypoints("white_box") + spot.set_base_position( + x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100, blocking=True + ) + time.sleep(4) + policy.reset() + observations = env.reset(target_obj_id=target_id) + done = False + env.say("Looking for", object_id_to_object_name(target_id)) + while not done: + action = policy.act(observations) + observations, _, done, _ = env.step(arm_action=action) + + +if __name__ == "__main__": + spot = Spot("MultipleGazeEnv") + with spot.get_lease(hijack=True): + try: + main(spot) + finally: + spot.power_off() diff --git a/spot_rl_experiments/experiments/comparisons/nav_compare.py b/spot_rl_experiments/experiments/comparisons/nav_compare.py new file mode 100644 index 00000000..c73f2f54 --- /dev/null +++ b/spot_rl_experiments/experiments/comparisons/nav_compare.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import time +from collections import defaultdict + +import numpy as np +from spot_rl.envs.nav_env import SpotNavEnv +from spot_rl.real_policy import NavPolicy +from spot_rl.utils.utils import construct_config +from spot_wrapper.spot import Spot, wrap_heading +from spot_wrapper.utils import say + +ROUTES = [ + ((6.0, 2.0, 0.0), (6.0, -1.0, 0.0)), + ((3.5, 1.5, np.pi), (-0.25, 1.5, np.pi)), + ((1.0, 0.0, 0.0), (8, -5.18, -np.pi / 2)), +] + + +def main(spot, idx): + config = construct_config([]) + policy = NavPolicy(config.WEIGHTS.NAV, device=config.DEVICE) + + env = SpotNavEnv(config, spot) + env.power_robot() + + start_waypoint, goal_waypoint = ROUTES[idx] + + return_to_start(spot, start_waypoint, policy, env) + time.sleep(2) + datas = defaultdict(list) + times = defaultdict(list) + + for ctrl_idx, nav_func in enumerate([learned_navigate, baseline_navigate]): + for _ in range(3): + say("Starting episode.") + time.sleep(3) + st = time.time() + traj = nav_func(spot=spot, waypoint=goal_waypoint, policy=policy, env=env) + traj_time = time.time() - st + datas[ctrl_idx].append((traj, traj_time)) + times[ctrl_idx].append(traj_time) + spot.set_base_velocity(0, 0, 0, 1) + say("Done with episode. Returning.") + time.sleep(3) + print("Returning...") + if idx == 2: + return_to_start(spot, (8.0, -1.0, np.pi), policy, env, no_learn=True) + return_to_start(spot, start_waypoint, policy, env) + print("Done returning.") + + for k, v in times.items(): + name = ["Learned", "BDAPI"][k] + print(f"{name} completion times:") + for vv in v: + print(vv) + + for ctrl_idx, trajs in datas.items(): + for ep_id, (traj, traj_time) in enumerate(trajs): + data = [str(traj_time)] + for t_x_y_yaw in traj: + data.append(",".join([str(i) for i in t_x_y_yaw])) + name = ["learned", "bdapi"][ctrl_idx] + with open(f"route_{idx}_ep_{ep_id}_{name}.txt", "w") as f: + f.write("\n".join(data) + "\n") + + +def baseline_navigate(spot, waypoint, limits=True, **kwargs): + goal_x, goal_y, goal_heading = waypoint + if limits: + cmd_id = spot.set_base_position( + x_pos=goal_x, + y_pos=goal_y, + yaw=goal_heading, + end_time=100, + max_fwd_vel=0.5, + max_hor_vel=0.05, + max_ang_vel=np.deg2rad(30), + ) + else: + cmd_id = spot.set_base_position( + x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100 + ) + cmd_status = None + success = False + traj = [] + st = time.time() + while not success and time.time() < st + 20: + if cmd_status != 1: + traj.append((time.time(), *spot.get_xy_yaw())) + time.sleep(0.5) + feedback_resp = spot.get_cmd_feedback(cmd_id) + cmd_status = ( + feedback_resp.feedback.synchronized_feedback.mobility_command_feedback + ).se2_trajectory_feedback.status + else: + if limits: + cmd_id = spot.set_base_position( + x_pos=goal_x, + y_pos=goal_y, + yaw=goal_heading, + end_time=100, + max_fwd_vel=0.5, + max_hor_vel=0.05, + max_ang_vel=np.deg2rad(30), + ) + else: + cmd_id = spot.set_base_position( + x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100 + ) + + x, y, yaw = spot.get_xy_yaw() + dist = np.linalg.norm(np.array([x, y]) - np.array([goal_x, goal_y])) + heading_diff = abs(wrap_heading(goal_heading - yaw)) + success = dist < 0.3 and heading_diff < np.deg2rad(5) + + return traj + + +def learned_navigate(waypoint, policy, env, **kwargs): + goal_x, goal_y, goal_heading = waypoint + observations = env.reset((goal_x, goal_y), goal_heading) + done = False + policy.reset() + traj = [] + while not done: + traj.append((time.time(), *env.spot.get_xy_yaw())) + action = policy.act(observations) + observations, _, done, _ = env.step(base_action=action) + + return traj + + +def return_to_start(spot, waypoint, policy, env, no_learn=False): + # goal_x, goal_y, goal_heading = waypoint + if not no_learn: + learned_navigate(waypoint, policy, env) + baseline_navigate(spot, waypoint, limits=False) + # spot.set_base_position( + # x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100, blocking=True + # ) + + +if __name__ == "__main__": + spot = Spot("NavCompare") + parser = argparse.ArgumentParser() + parser.add_argument("idx", type=int) + args = parser.parse_args() + with spot.get_lease(hijack=True): + main(spot, args.idx) diff --git a/spot_rl_experiments/generate_executables.py b/spot_rl_experiments/generate_executables.py new file mode 100644 index 00000000..6d8e6f1e --- /dev/null +++ b/spot_rl_experiments/generate_executables.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import os.path as osp +import sys + +this_dir = osp.dirname(osp.abspath(__file__)) +base_dir = osp.join(this_dir, "spot_rl") +bin_dir = osp.join(os.environ["CONDA_PREFIX"], "bin") + +orig_to_alias = { + "envs.gaze_env": "spot_rl_gaze_env", + "envs.mobile_manipulation_env": "spot_rl_mobile_manipulation_env", + "envs.nav_env": "spot_rl_nav_env", + "envs.place_env": "spot_rl_place_env", + "baselines.go_to_waypoint": "spot_rl_go_to_waypoint", + "utils.autodock": "spot_rl_autodock", + "utils.waypoint_recorder": "spot_rl_waypoint_recorder", + "ros_img_vis": "spot_rl_ros_img_vis", + "launch/core.sh": "spot_rl_launch_core", + "launch/local_listener.sh": "spot_rl_launch_listener", + "launch/local_only.sh": "spot_rl_launch_local", + "launch/kill_sessions.sh": "spot_rl_kill_sessions", +} + +print("Generating executables...") +for orig, alias in orig_to_alias.items(): + exe_path = osp.join(bin_dir, alias) + if orig.endswith(".sh"): + data = f"#!/usr/bin/env bash \nsource {osp.join(base_dir, orig)}\n" + else: + data = f"#!/usr/bin/env bash \n{sys.executable} -m spot_rl.{orig} $@\n" + with open(exe_path, "w") as f: + f.write(data) + os.chmod(exe_path, 33277) + print("Added:", alias) +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("THESE EXECUTABLES ARE ONLY VISIBLE TO THE CURRENT CONDA ENV!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") diff --git a/spot_rl_experiments/setup.py b/spot_rl_experiments/setup.py new file mode 100644 index 00000000..292d8af1 --- /dev/null +++ b/spot_rl_experiments/setup.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import setuptools + +setuptools.setup( + name="spot_rl", + version="0.1", + author="Naoki Yokoyama", + author_email="naokiyokoyama@github", + description="Python wrapper for Boston Dynamics Spot Robot", + url="https://github.com/naokiyokoyama/spot_rl_wrapper", + packages=setuptools.find_packages(), +) diff --git a/spot_rl_experiments/spot_rl/baselines/go_to_waypoint.py b/spot_rl_experiments/spot_rl/baselines/go_to_waypoint.py new file mode 100644 index 00000000..03833550 --- /dev/null +++ b/spot_rl_experiments/spot_rl/baselines/go_to_waypoint.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import time + +import numpy as np +from spot_rl.utils.utils import get_default_parser, nav_target_from_waypoints +from spot_wrapper.spot import Spot + +DOCK_ID = int(os.environ.get("SPOT_DOCK_ID", 520)) + + +def main(spot): + parser = get_default_parser() + parser.add_argument("-g", "--goal") + parser.add_argument("-w", "--waypoint") + parser.add_argument("-d", "--dock", action="store_true") + parser.add_argument("-l", "--limit", action="store_true") + args = parser.parse_args() + if args.waypoint is not None: + goal_x, goal_y, goal_heading = nav_target_from_waypoints(args.waypoint) + else: + assert args.goal is not None + goal_x, goal_y, goal_heading = [float(i) for i in args.goal.split(",")] + + if args.limit: + kwargs = { + "max_fwd_vel": 0.5, + "max_hor_vel": 0.05, + "max_ang_vel": np.deg2rad(30), + } + else: + kwargs = {} + + spot.power_on() + spot.blocking_stand() + try: + cmd_id = spot.set_base_position( + x_pos=goal_x, + y_pos=goal_y, + yaw=goal_heading, + end_time=100, + **kwargs, + ) + cmd_status = None + while cmd_status != 1: + time.sleep(0.1) + feedback_resp = spot.get_cmd_feedback(cmd_id) + cmd_status = ( + feedback_resp.feedback.synchronized_feedback.mobility_command_feedback + ).se2_trajectory_feedback.status + if args.dock: + dock_start_time = time.time() + while time.time() - dock_start_time < 2: + try: + spot.dock(dock_id=DOCK_ID, home_robot=True) + except Exception: + print("Dock not found... trying again") + time.sleep(0.1) + finally: + spot.power_off() + + +if __name__ == "__main__": + spot = Spot("GoToWaypoint") + with spot.get_lease(hijack=True): + main(spot) diff --git a/spot_rl_experiments/spot_rl/envs/base_env.py b/spot_rl_experiments/spot_rl/envs/base_env.py new file mode 100644 index 00000000..195be806 --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/base_env.py @@ -0,0 +1,838 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import os +import os.path as osp +import time + +import cv2 +import gym +from spot_rl.utils.img_publishers import MAX_HAND_DEPTH +from spot_rl.utils.mask_rcnn_utils import ( + generate_mrcnn_detections, + get_deblurgan_model, + get_mrcnn_model, + pred2string, +) +from spot_rl.utils.robot_subscriber import SpotRobotSubscriberMixin +from spot_rl.utils.stopwatch import Stopwatch + +try: + import magnum as mn +except Exception: + pass + +import numpy as np +import quaternion +import rospy + +try: + from deblur_gan.predictor import DeblurGANv2 + from mask_rcnn_detectron2.inference import MaskRcnnInference +except Exception: + pass + +from sensor_msgs.msg import Image +from spot_rl.utils.utils import FixSizeOrderedDict, arr2str, object_id_to_object_name +from spot_rl.utils.utils import ros_topics as rt +from spot_wrapper.spot import Spot, wrap_heading +from std_msgs.msg import Float32, String + +MAX_CMD_DURATION = 5 +GRASP_VIS_DIR = osp.join( + osp.dirname(osp.dirname(osp.abspath(__file__))), "grasp_visualizations" +) +if not osp.isdir(GRASP_VIS_DIR): + os.mkdir(GRASP_VIS_DIR) + +DETECTIONS_BUFFER_LEN = 30 +LEFT_CROP = 124 +RIGHT_CROP = 60 +NEW_WIDTH = 228 +NEW_HEIGHT = 240 +ORIG_WIDTH = 640 +ORIG_HEIGHT = 480 +WIDTH_SCALE = 0.5 +HEIGHT_SCALE = 0.5 + + +def pad_action(action): + """We only control 4 out of 6 joints; add zeros to non-controllable indices.""" + return np.array([*action[:3], 0.0, action[3], 0.0]) + + +def rescale_actions(actions, action_thresh=0.05, silence_only=False): + actions = np.clip(actions, -1, 1) + # Silence low actions + actions[np.abs(actions) < action_thresh] = 0.0 + if silence_only: + return actions + + # Remap action scaling to compensate for silenced values + action_offsets = np.ones_like(actions) * action_thresh + action_offsets[actions < 0] = -action_offsets[actions < 0] + action_offsets[actions == 0] = 0 + actions = (actions - np.array(action_offsets)) / (1.0 - action_thresh) + + return actions + + +class SpotBaseEnv(SpotRobotSubscriberMixin, gym.Env): + node_name = "spot_reality_gym" + no_raw = True + proprioception = True + + def __init__(self, config, spot: Spot, stopwatch=None): + self.detections_buffer = { + k: FixSizeOrderedDict(maxlen=DETECTIONS_BUFFER_LEN) + for k in ["detections", "filtered_depth", "viz"] + } + + super().__init__(spot=spot) + + self.config = config + self.spot = spot + if stopwatch is None: + stopwatch = Stopwatch() + self.stopwatch = stopwatch + + # General environment parameters + self.ctrl_hz = config.CTRL_HZ + self.max_episode_steps = config.MAX_EPISODE_STEPS + self.num_steps = 0 + self.reset_ran = False + + # Base action parameters + self.max_lin_dist = config.MAX_LIN_DIST + self.max_ang_dist = np.deg2rad(config.MAX_ANG_DIST) + + # Arm action parameters + self.initial_arm_joint_angles = np.deg2rad(config.INITIAL_ARM_JOINT_ANGLES) + self.arm_lower_limits = np.deg2rad(config.ARM_LOWER_LIMITS) + self.arm_upper_limits = np.deg2rad(config.ARM_UPPER_LIMITS) + self.locked_on_object_count = 0 + self.grasp_attempted = False + self.place_attempted = False + self.detection_timestamp = -1 + + self.forget_target_object_steps = config.FORGET_TARGET_OBJECT_STEPS + self.curr_forget_steps = 0 + self.obj_center_pixel = None + self.target_obj_name = None + self.last_target_obj = None + self.use_mrcnn = True + self.target_object_distance = -1 + self.detections_str_synced = "None" + self.latest_synchro_obj_detection = None + self.mrcnn_viz = None + self.last_seen_objs = [] + self.slowdown_base = -1 + self.prev_base_moved = False + self.should_end = False + + # Text-to-speech + self.tts_pub = rospy.Publisher(rt.TEXT_TO_SPEECH, String, queue_size=1) + + # Mask RCNN / Gaze + self.parallel_inference_mode = config.PARALLEL_INFERENCE_MODE + if config.PARALLEL_INFERENCE_MODE: + if config.USE_MRCNN: + rospy.Subscriber(rt.DETECTIONS_TOPIC, String, self.detections_cb) + rospy.loginfo( + f"[{self.node_name}]: Parallel inference selected: Waiting for Mask R-CNN msgs..." + ) + st = time.time() + while ( + len(self.detections_buffer["detections"]) == 0 + and time.time() < st + 25 + ): + pass + assert ( + len(self.detections_buffer["detections"]) > 0 + ), "Mask R-CNN msgs not found!" + rospy.loginfo(f"[{self.node_name}]: ...msgs received.") + scale_pub = rospy.Publisher(rt.IMAGE_SCALE, Float32, queue_size=1) + scale_pub.publish(config.IMAGE_SCALE) + elif config.USE_MRCNN: + self.mrcnn = get_mrcnn_model(config) + self.deblur_gan = get_deblurgan_model(config) + + if config.USE_MRCNN: + self.mrcnn_viz_pub = rospy.Publisher( + rt.MASK_RCNN_VIZ_TOPIC, Image, queue_size=1 + ) + + if config.USE_HEAD_CAMERA: + print("Waiting for filtered depth msgs...") + st = time.time() + while self.filtered_head_depth is None and time.time() < st + 15: + pass + assert self.filtered_head_depth is not None, "Depth msgs not found!" + print("...msgs received.") + + @property + def filtered_hand_depth(self): + return self.msgs[rt.FILTERED_HAND_DEPTH] + + @property + def filtered_head_depth(self): + return self.msgs[rt.FILTERED_HEAD_DEPTH] + + @property + def filtered_hand_rgb(self): + return self.msgs[rt.HAND_RGB] + + def detections_cb(self, msg): + timestamp, detections_str = msg.data.split("|") + self.detections_buffer["detections"][int(timestamp)] = detections_str + + def img_callback(self, topic, msg): + super().img_callback(topic, msg) + if topic == rt.MASK_RCNN_VIZ_TOPIC: + self.detections_buffer["viz"][int(msg.header.stamp.nsecs)] = msg + elif topic == rt.FILTERED_HAND_DEPTH: + self.detections_buffer["filtered_depth"][int(msg.header.stamp.nsecs)] = msg + + def say(self, *args): + text = " ".join(args) + print("[base_env.py]: Saying:", text) + self.tts_pub.publish(String(text)) + + def reset(self, target_obj_id=None, *args, **kwargs): + # Reset parameters + self.num_steps = 0 + self.reset_ran = True + self.grasp_attempted = False + self.use_mrcnn = True + self.locked_on_object_count = 0 + self.curr_forget_steps = 0 + self.target_obj_name = target_obj_id + self.last_target_obj = None + self.obj_center_pixel = None + self.place_attempted = False + self.detection_timestamp = -1 + self.slowdown_base = -1 + self.prev_base_moved = False + self.should_end = False + + observations = self.get_observations() + return observations + + def step( # noqa + self, + base_action=None, + arm_action=None, + grasp=False, + place=False, + max_joint_movement_key="MAX_JOINT_MOVEMENT", + nav_silence_only=True, + disable_oa=None, + ): + """Moves the arm and returns updated observations + + :param base_action: np.array of velocities (linear, angular) + :param arm_action: np.array of radians denoting how each joint is to be moved + :param grasp: whether to call the grasp_hand_depth() method + :param place: whether to call the open_gripper() method + :param max_joint_movement_key: max allowable displacement of arm joints + (different for gaze and place) + :return: observations, reward (None), done, info + """ + assert self.reset_ran, ".reset() must be called first!" + target_yaw = None + if disable_oa is None: + disable_oa = self.config.DISABLE_OBSTACLE_AVOIDANCE + grasp = grasp or self.config.GRASP_EVERY_STEP + print(f"raw_base_ac: {arr2str(base_action)}\traw_arm_ac: {arr2str(arm_action)}") + if grasp: + # Briefly pause and get latest gripper image to ensure precise grasp + time.sleep(0.5) + self.get_gripper_images(save_image=True) + + if self.curr_forget_steps == 0: + print(f"GRASP CALLED: Aiming at (x, y): {self.obj_center_pixel}!") + self.say("Grasping " + self.target_obj_name) + + # The following cmd is blocking + success = self.attempt_grasp() + if success: + # Just leave the object on the receptacle if desired + if self.config.DONT_PICK_UP: + print("open_gripper in don't pick up") + self.spot.open_gripper() + self.grasp_attempted = True + arm_positions = np.deg2rad(self.config.PLACE_ARM_JOINT_ANGLES) + else: + self.say("BD grasp API failed.") + self.spot.open_gripper() + self.locked_on_object_count = 0 + arm_positions = np.deg2rad(self.config.GAZE_ARM_JOINT_ANGLES) + time.sleep(2) + + # Revert joint positions after grasp + self.spot.set_arm_joint_positions( + positions=arm_positions, travel_time=1.0 + ) + # Wait for arm to return to position + time.sleep(1.0) + if self.config.TERMINATE_ON_GRASP: + self.should_end = True + elif place: + print("PLACE ACTION CALLED: Opening the gripper!") + if self.get_grasp_angle_to_xy() < np.deg2rad(30): + self.turn_wrist() + print("open gripper in place") + self.spot.open_gripper() + time.sleep(0.3) + self.place_attempted = True + if base_action is not None: + if nav_silence_only: + base_action = rescale_actions(base_action, silence_only=True) + else: + base_action = np.clip(base_action, -1, 1) + if np.count_nonzero(base_action) > 0: + # Command velocities using the input action + lin_dist, ang_dist = base_action + lin_dist *= self.max_lin_dist + ang_dist *= self.max_ang_dist + target_yaw = wrap_heading(self.yaw + ang_dist) + # No horizontal velocity + ctrl_period = 1 / self.ctrl_hz + # Don't even bother moving if it's just for a bit of distance + if abs(lin_dist) < 0.05 and abs(ang_dist) < np.deg2rad(3): + base_action = None + target_yaw = None + else: + base_action = [lin_dist / ctrl_period, 0, ang_dist / ctrl_period] + self.prev_base_moved = True + else: + base_action = None + self.prev_base_moved = False + if arm_action is not None: + arm_action = rescale_actions(arm_action) + if np.count_nonzero(arm_action) > 0: + arm_action *= self.config[max_joint_movement_key] + arm_action = self.current_arm_pose + pad_action(arm_action) + arm_action = np.clip( + arm_action, self.arm_lower_limits, self.arm_upper_limits + ) + else: + arm_action = None + + if not (grasp or place): + if self.slowdown_base > -1 and base_action is not None: + # self.ctrl_hz = self.slowdown_base + base_action = ( + np.array(base_action) * self.slowdown_base + ) # / self.ctrl_hz + if base_action is not None and arm_action is not None: + self.spot.set_base_vel_and_arm_pos( + *base_action, + arm_action, + MAX_CMD_DURATION, + disable_obstacle_avoidance=disable_oa, + ) + elif base_action is not None: + self.spot.set_base_velocity( + *base_action, + MAX_CMD_DURATION, + disable_obstacle_avoidance=disable_oa, + ) + elif arm_action is not None: + self.spot.set_arm_joint_positions( + positions=arm_action, travel_time=1 / self.ctrl_hz * 0.9 + ) + + if self.prev_base_moved and base_action is None: + self.spot.stand() + + print(f"base_action: {arr2str(base_action)}\tarm_action: {arr2str(arm_action)}") + + # Spin until enough time has passed during this step + start_time = time.time() + if base_action is not None or arm_action is not None: + while time.time() < start_time + 1 / self.ctrl_hz: + if target_yaw is not None and abs( + wrap_heading(self.yaw - target_yaw) + ) < np.deg2rad(3): + # Prevent overshooting of angular velocity + self.spot.set_base_velocity(base_action[0], 0, 0, MAX_CMD_DURATION) + target_yaw = None + elif not (grasp or place): + print("!!!! NO ACTIONS CALLED: moving to next step !!!!") + self.num_steps -= 1 + + self.stopwatch.record("run_actions") + if base_action is not None: + self.spot.set_base_velocity(0, 0, 0, 0.5) + + observations = self.get_observations() + self.stopwatch.record("get_observations") + + self.num_steps += 1 + timeout = self.num_steps >= self.max_episode_steps + done = timeout or self.get_success(observations) or self.should_end + self.ctrl_hz = self.config.CTRL_HZ # revert ctrl_hz in case it slowed down + + # Don't need reward or info + reward = None + info = {"num_steps": self.num_steps} + + return observations, reward, done, info + + def attempt_grasp(self): + pre_grasp = time.time() + ret = self.spot.grasp_hand_depth( + self.obj_center_pixel, top_down_grasp=True, timeout=10 + ) + if self.config.USE_REMOTE_SPOT: + ret = time.time() - pre_grasp > 3 # TODO: Make this better... + return ret + + @staticmethod + def get_nav_success(observations, success_distance, success_angle): + # Is the agent at the goal? + dist_to_goal, _ = observations["target_point_goal_gps_and_compass_sensor"] + at_goal = dist_to_goal < success_distance + good_heading = abs(observations["goal_heading"][0]) < success_angle + return at_goal and good_heading + + def print_nav_stats(self, observations): + rho, theta = observations["target_point_goal_gps_and_compass_sensor"] + print( + f"Dist to goal: {rho:.2f}\t" + f"theta: {np.rad2deg(theta):.2f}\t" + f"x: {self.x:.2f}\t" + f"y: {self.y:.2f}\t" + f"yaw: {np.rad2deg(self.yaw):.2f}\t" + f"gh: {np.rad2deg(observations['goal_heading'][0]):.2f}\t" + ) + + def get_nav_observation(self, goal_xy, goal_heading): + observations = {} + + # Get visual observations + front_depth = self.msg_to_cv2(self.filtered_head_depth, "mono8") + + front_depth = cv2.resize( + front_depth, (120 * 2, 212), interpolation=cv2.INTER_AREA + ) + front_depth = np.float32(front_depth) / 255.0 + # Add dimension for channel (unsqueeze) + front_depth = front_depth.reshape(*front_depth.shape[:2], 1) + observations["spot_right_depth"], observations["spot_left_depth"] = np.split( + front_depth, 2, 1 + ) + + # Get rho theta observation + curr_xy = np.array([self.x, self.y], dtype=np.float32) + rho = np.linalg.norm(curr_xy - goal_xy) + theta = wrap_heading( + np.arctan2(goal_xy[1] - self.y, goal_xy[0] - self.x) - self.yaw + ) + rho_theta = np.array([rho, theta], dtype=np.float32) + + # Get goal heading observation + goal_heading_ = -np.array( + [wrap_heading(goal_heading - self.yaw)], dtype=np.float32 + ) + observations["target_point_goal_gps_and_compass_sensor"] = rho_theta + observations["goal_heading"] = goal_heading_ + + return observations + + def get_arm_joints(self): + # Get proprioception inputs + joints = np.array( + [ + j + for idx, j in enumerate(self.current_arm_pose) + if idx not in self.config.JOINT_BLACKLIST + ], + dtype=np.float32, + ) + + return joints + + def get_gripper_images(self, save_image=False): + if self.grasp_attempted: + # Return blank images if the gripper is being blocked + blank_img = np.zeros([NEW_HEIGHT, NEW_WIDTH, 1], dtype=np.float32) + return blank_img, blank_img.copy() + if self.parallel_inference_mode: + self.detection_timestamp = None + # Use .copy() to prevent mutations during iteration + for i in reversed(self.detections_buffer["detections"].copy()): + if ( + i in self.detections_buffer["detections"] + and i in self.detections_buffer["filtered_depth"] + ): + self.detection_timestamp = i + break + if self.detection_timestamp is None: + raise RuntimeError("Could not correctly synchronize gaze observations") + self.detections_str_synced, filtered_hand_depth = ( + self.detections_buffer["detections"][self.detection_timestamp], + self.detections_buffer["filtered_depth"][self.detection_timestamp], + ) + arm_depth = self.msg_to_cv2(filtered_hand_depth, "mono8") + else: + arm_depth = self.msg_to_cv2(self.filtered_hand_depth, "mono8") + + # Crop out black vertical bars on the left and right edges of aligned depth img + arm_depth = arm_depth[:, LEFT_CROP:-RIGHT_CROP] + arm_depth = cv2.resize( + arm_depth, (NEW_WIDTH, NEW_HEIGHT), interpolation=cv2.INTER_AREA + ) + arm_depth = arm_depth.reshape([*arm_depth.shape, 1]) # unsqueeze + arm_depth = np.float32(arm_depth) / 255.0 + + # Generate object mask channel + if self.use_mrcnn: + obj_bbox = self.update_gripper_detections(arm_depth, save_image) + else: + obj_bbox = None + + if obj_bbox is not None: + self.target_object_distance, arm_depth_bbox = get_obj_dist_and_bbox( + obj_bbox, arm_depth + ) + else: + self.target_object_distance = -1 + arm_depth_bbox = np.zeros_like(arm_depth, dtype=np.float32) + + return arm_depth, arm_depth_bbox + + def update_gripper_detections(self, arm_depth, save_image=False): + det = self.get_mrcnn_det(arm_depth, save_image=save_image) + if det is None: + self.curr_forget_steps += 1 + self.locked_on_object_count = 0 + return det + + def get_mrcnn_det(self, arm_depth, save_image=False): + marked_img = None + if self.parallel_inference_mode: + detections_str = str(self.detections_str_synced) + else: + img = self.msg_to_cv2(self.msgs[rt.HAND_RGB]) + if save_image: + marked_img = img.copy() + pred = generate_mrcnn_detections( + img, + scale=self.config.IMAGE_SCALE, + mrcnn=self.mrcnn, + grayscale=True, + deblurgan=self.deblur_gan, + ) + detections_str = pred2string(pred) + + print("DETECTIONS STR") + print(detections_str) + # 0,0.9983996748924255,80.09086608886719,228.12628173828125,171.05816650390625,312.3734130859375 mrcnn + # 303,243,397,332 owlvit + + # If we haven't seen the current target object in a while, look for new ones + if self.curr_forget_steps >= self.forget_target_object_steps: + self.target_obj_name = None + + if detections_str != "None": + detected_classes = [] + for i in detections_str.split(";"): + label = i.split(",")[0] + + # Maskrnn return ids but other object detectors return already the string class + if label.isdigit(): + label = object_id_to_object_name(int(label)) + detected_classes.append(label) + + print("[bounding_box]: Detected:", ", ".join(detected_classes)) + + if self.target_obj_name is None: + most_confident_score = 0.0 + good_detections = [] + for det in detections_str.split(";"): + class_detected, score = det.split(",")[:2] + score = float(score) + if class_detected.isdigit(): + label = object_id_to_object_name(int(class_detected)) + else: + label = class_detected + dist = get_obj_dist_and_bbox(self.get_det_bbox(det), arm_depth)[0] + + if score > 0.001 and dist < MAX_HAND_DEPTH: + good_detections.append(label) + if score > most_confident_score: + most_confident_score = score + most_confident_name = label + if most_confident_score == 0.0: + return None + + if most_confident_name in self.last_seen_objs: + self.target_obj_name = most_confident_name + if self.target_obj_name != self.last_target_obj: + self.say("Now targeting " + self.target_obj_name) + self.last_target_obj = self.target_obj_name + self.last_seen_objs = good_detections + else: + self.last_seen_objs = [] + else: + return None + + # Check if desired object is in view of camera + targ_obj_name = self.target_obj_name + + def correct_class(detection): + class_detected = detection.split(",")[0] + if class_detected.isdigit(): + return object_id_to_object_name(int(class_detected)) == targ_obj_name + else: + return class_detected == targ_obj_name + + matching_detections = [d for d in detections_str.split(";") if correct_class(d)] + + if not matching_detections: + return None + + self.curr_forget_steps = 0 + + # Get object match with the highest score + def get_score(detection): + return float(detection.split(",")[1]) + + best_detection = sorted(matching_detections, key=get_score)[-1] + x1, y1, x2, y2 = self.get_det_bbox(best_detection) + + # Create bbox mask from selected detection + cx = int(np.mean([x1, x2])) + cy = int(np.mean([y1, y2])) + self.obj_center_pixel = (cx, cy) + + if save_image: + if marked_img is None: + while self.detection_timestamp not in self.detections_buffer["viz"]: + pass + viz_img = self.detections_buffer["viz"][self.detection_timestamp] + marked_img = self.cv_bridge.imgmsg_to_cv2(viz_img) + marked_img = cv2.resize( + marked_img, + (0, 0), + fx=1 / self.config.IMAGE_SCALE, + fy=1 / self.config.IMAGE_SCALE, + interpolation=cv2.INTER_AREA, + ) + marked_img = cv2.circle(marked_img, (cx, cy), 5, (0, 0, 255), -1) + marked_img = cv2.rectangle(marked_img, (x1, y1), (x2, y2), (0, 0, 255)) + out_path = osp.join(GRASP_VIS_DIR, f"{time.time()}.png") + cv2.imwrite(out_path, marked_img) + print("Saved grasp image as", out_path) + img_msg = self.cv_bridge.cv2_to_imgmsg(marked_img) + self.mrcnn_viz_pub.publish(img_msg) + + height, width = (480, 640) + locked_on = self.locked_on_object(x1, y1, x2, y2, height, width) + if locked_on: + self.locked_on_object_count += 1 + print(f"Locked on to target {self.locked_on_object_count} time(s)...") + else: + if self.locked_on_object_count > 0: + print("Lost lock-on!") + self.locked_on_object_count = 0 + + return x1, y1, x2, y2 + + def get_det_bbox(self, det): + img_scale_factor = self.config.IMAGE_SCALE + x1, y1, x2, y2 = [int(float(i) / img_scale_factor) for i in det.split(",")[-4:]] + return x1, y1, x2, y2 + + @staticmethod + def locked_on_object(x1, y1, x2, y2, height, width, radius=0.15): + cy, cx = height // 2, width // 2 + # Locked on if the center of the image is in the bbox + if x1 < cx < x2 and y1 < cy < y2: + return True + + pixel_radius = min(height, width) * radius + # Get pixel distance between bbox rectangle and the center of the image + # Stack Overflow question ID #5254838 + dx = np.max([x1 - cx, 0, cx - x2]) + dy = np.max([y1 - cy, 0, cy - y2]) + bbox_dist = np.sqrt(dx**2 + dy**2) + locked_on = bbox_dist < pixel_radius + + return locked_on + + def should_grasp(self): + grasp = False + if self.locked_on_object_count >= self.config.OBJECT_LOCK_ON_NEEDED: + if self.target_object_distance < 1.5: + if self.config.ASSERT_CENTERING: + x, y = self.obj_center_pixel + if abs(x / 640 - 0.5) < 0.25 or abs(y / 480 - 0.5) < 0.25: + grasp = True + else: + print("Too off center to grasp!:", x / 640, y / 480) + else: + print(f"Too far to grasp ({self.target_object_distance})!") + + return grasp + + def get_observations(self): + raise NotImplementedError + + def get_success(self, observations): + raise NotImplementedError + + @staticmethod + def spot2habitat_transform(position, rotation): + x, y, z = position + qx, qy, qz, qw = rotation + + quat = quaternion.quaternion(qw, qx, qy, qz) + rotation_matrix = mn.Quaternion(quat.imag, quat.real).to_matrix() + rotation_matrix_fixed = ( + rotation_matrix + @ mn.Matrix4.rotation( + mn.Rad(-np.pi / 2.0), mn.Vector3(1.0, 0.0, 0.0) + ).rotation() + ) + translation = mn.Vector3(x, z, -y) + + quat_rotated = mn.Quaternion.from_matrix(rotation_matrix_fixed) + quat_rotated.vector = mn.Vector3( + quat_rotated.vector[0], quat_rotated.vector[2], -quat_rotated.vector[1] + ) + rotation_matrix_fixed = quat_rotated.to_matrix() + sim_transform = mn.Matrix4.from_(rotation_matrix_fixed, translation) + + return sim_transform + + @staticmethod + def spot2habitat_translation(spot_translation): + return mn.Vector3(np.array(spot_translation)[np.array([0, 2, 1])]) + + @property + def curr_transform(self): + # Assume body is at default height of 0.5 m + # This is local_T_global. + return mn.Matrix4.from_( + mn.Matrix4.rotation_z(mn.Rad(self.yaw)).rotation(), + mn.Vector3(self.x, self.y, 0.5), + ) + + def get_place_sensor(self): + # The place goal should be provided relative to the local robot frame given that + # the robot is at the place receptacle + gripper_T_base = self.get_in_gripper_tf() + base_T_gripper = gripper_T_base.inverted() + base_frame_place_target = self.get_base_frame_place_target() + hab_place_target = self.spot2habitat_translation(base_frame_place_target) + gripper_pos = base_T_gripper.transform_point(hab_place_target) + + return gripper_pos + + def get_base_frame_place_target(self): + if self.place_target_is_local: + base_frame_place_target = self.place_target + else: + base_frame_place_target = self.get_target_in_base_frame(self.place_target) + return base_frame_place_target + + def get_place_distance(self): + gripper_T_base = self.get_in_gripper_tf() + base_frame_gripper_pos = np.array(gripper_T_base.translation) + base_frame_place_target = self.get_base_frame_place_target() + hab_place_target = self.spot2habitat_translation(base_frame_place_target) + hab_place_target = np.array(hab_place_target) + place_dist = np.linalg.norm(hab_place_target - base_frame_gripper_pos) + xy_dist = np.linalg.norm( + hab_place_target[[0, 2]] - base_frame_gripper_pos[[0, 2]] + ) + z_dist = abs(hab_place_target[1] - base_frame_gripper_pos[1]) + return place_dist, xy_dist, z_dist + + def get_in_gripper_tf(self): + wrist_T_base = self.spot2habitat_transform( + self.link_wr1_position, self.link_wr1_rotation + ) + gripper_T_base = wrist_T_base @ mn.Matrix4.translation(self.ee_gripper_offset) + + return gripper_T_base + + def get_target_in_base_frame(self, place_target): + global_T_local = self.curr_transform.inverted() + local_place_target = np.array(global_T_local.transform_point(place_target)) + local_place_target[1] *= -1 # Still not sure why this is necessary + + return local_place_target + + def get_grasp_object_angle(self, obj_translation): + """Calculates angle between gripper line-of-sight and given global position""" + camera_T_matrix = self.get_gripper_transform() + + # Get object location in camera frame + camera_obj_trans = ( + camera_T_matrix.inverted().transform_point(obj_translation).normalized() + ) + + # Get angle between (normalized) location and unit vector + object_angle = angle_between(camera_obj_trans, mn.Vector3(0, 0, -1)) + + return object_angle + + def get_grasp_angle_to_xy(self): + gripper_tf = self.get_in_gripper_tf() + gripper_cam_position = gripper_tf.translation + below_gripper = gripper_cam_position + mn.Vector3(0.0, -1.0, 0.0) + + # Get below gripper pos in gripper frame + gripper_obj_trans = ( + gripper_tf.inverted().transform_point(below_gripper).normalized() + ) + + # Get angle between (normalized) location and unit vector + object_angle = angle_between(gripper_obj_trans, mn.Vector3(0, 0, -1)) + + return object_angle + + def turn_wrist(self): + arm_positions = np.array(self.current_arm_pose) + arm_positions[-1] = np.deg2rad(90) + self.spot.set_arm_joint_positions(positions=arm_positions, travel_time=0.3) + time.sleep(0.6) + + def power_robot(self): + self.spot.power_on() + # self.say("Standing up") + try: + self.spot.undock() + except Exception: + print("Undocking failed: just standing up instead...") + self.spot.blocking_stand() + + +def get_obj_dist_and_bbox(obj_bbox, arm_depth): + x1, y1, x2, y2 = obj_bbox + x1 = max(int(float(x1 - LEFT_CROP) * WIDTH_SCALE), 0) + x2 = max(int(float(x2 - LEFT_CROP) * WIDTH_SCALE), 0) + y1 = int(float(y1) * HEIGHT_SCALE) + y2 = int(float(y2) * HEIGHT_SCALE) + arm_depth_bbox = np.zeros_like(arm_depth, dtype=np.float32) + arm_depth_bbox[y1:y2, x1:x2] = 1.0 + + # Estimate distance from the gripper to the object + depth_box = arm_depth[y1:y2, x1:x2] + + return np.median(depth_box) * MAX_HAND_DEPTH, arm_depth_bbox + + +def angle_between(v1, v2): + # stack overflow question ID: 2827393 + cosine = np.clip(np.dot(v1, v2), -1.0, 1.0) + object_angle = np.arccos(cosine) + + return object_angle diff --git a/spot_rl_experiments/spot_rl/envs/gaze_env.py b/spot_rl_experiments/spot_rl/envs/gaze_env.py new file mode 100644 index 00000000..8f5a281e --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/gaze_env.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time + +import cv2 +import numpy as np +from spot_rl.envs.base_env import SpotBaseEnv +from spot_rl.real_policy import GazePolicy +from spot_rl.utils.utils import construct_config, get_default_parser +from spot_wrapper.spot import Spot, wrap_heading + +DEBUG = False + + +def run_env(spot, config, target_obj_id=None, orig_pos=None): + # Don't need head cameras for Gaze + config.USE_HEAD_CAMERA = False + + env = SpotGazeEnv(config, spot) + env.power_robot() + policy = GazePolicy(config.WEIGHTS.GAZE, device=config.DEVICE) + policy.reset() + observations = env.reset(target_obj_id=target_obj_id) + done = False + env.say("Starting episode") + if orig_pos is None: + orig_pos = (float(env.x), float(env.y), np.pi) + while not done: + action = policy.act(observations) + observations, _, done, _ = env.step(arm_action=action) + # print("Returning to original position...") + # baseline_navigate(spot, orig_pos, limits=False) + # print("Returned.") + if done: + while True: + spot.set_base_velocity(0, 0, 0, 1.0) + return done + + +def baseline_navigate(spot, waypoint, limits=True, **kwargs): + goal_x, goal_y, goal_heading = waypoint + if limits: + cmd_id = spot.set_base_position( + x_pos=goal_x, + y_pos=goal_y, + yaw=goal_heading, + end_time=100, + max_fwd_vel=0.5, + max_hor_vel=0.05, + max_ang_vel=np.deg2rad(30), + ) + else: + cmd_id = spot.set_base_position( + x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100 + ) + cmd_status = None + success = False + traj = [] + st = time.time() + while not success and time.time() < st + 20: + if cmd_status != 1: + traj.append((time.time(), *spot.get_xy_yaw())) + time.sleep(0.5) + feedback_resp = spot.get_cmd_feedback(cmd_id) + cmd_status = ( + feedback_resp.feedback.synchronized_feedback.mobility_command_feedback + ).se2_trajectory_feedback.status + else: + if limits: + cmd_id = spot.set_base_position( + x_pos=goal_x, + y_pos=goal_y, + yaw=goal_heading, + end_time=100, + max_fwd_vel=0.5, + max_hor_vel=0.05, + max_ang_vel=np.deg2rad(30), + ) + else: + cmd_id = spot.set_base_position( + x_pos=goal_x, y_pos=goal_y, yaw=goal_heading, end_time=100 + ) + + x, y, yaw = spot.get_xy_yaw() + dist = np.linalg.norm(np.array([x, y]) - np.array([goal_x, goal_y])) + heading_diff = abs(wrap_heading(goal_heading - yaw)) + success = dist < 0.3 and heading_diff < np.deg2rad(5) + + return traj + + +def close_enough(pos1, pos2): + dist = np.linalg.norm(np.array([pos1[0] - pos2[0], pos1[1] - pos2[1]])) + theta = abs(wrap_heading(pos1[2] - pos2[2])) + return dist < 0.1 and theta < np.deg2rad(2) + + +class SpotGazeEnv(SpotBaseEnv): + def reset(self, target_obj_id=None, *args, **kwargs): + # Move arm to initial configuration + cmd_id = self.spot.set_arm_joint_positions( + positions=self.initial_arm_joint_angles, travel_time=1 + ) + self.spot.block_until_arm_arrives(cmd_id, timeout_sec=1) + print("Open gripper called in Gaze") + self.spot.open_gripper() + + observations = super().reset(target_obj_id=target_obj_id, *args, **kwargs) + + # Reset parameters + self.locked_on_object_count = 0 + if target_obj_id is None: + self.target_obj_name = self.config.TARGET_OBJ_NAME + + return observations + + def step(self, base_action=None, arm_action=None, grasp=False, place=False): + grasp = self.should_grasp() + + observations, reward, done, info = super().step( + base_action, arm_action, grasp, place + ) + + return observations, reward, done, info + + def get_observations(self): + arm_depth, arm_depth_bbox = self.get_gripper_images() + if DEBUG: + img = np.uint8(arm_depth_bbox * 255).reshape(*arm_depth_bbox.shape[:2]) + img2 = np.uint8(arm_depth * 255).reshape(*arm_depth.shape[:2]) + cv2.imwrite(f"arm_bbox_{self.num_steps:03}.png", img) + cv2.imwrite(f"arm_depth_{self.num_steps:03}.png", img2) + observations = { + "joint": self.get_arm_joints(), + "arm_depth": arm_depth, + "arm_depth_bbox": arm_depth_bbox, + } + + return observations + + def get_success(self, observations): + return self.grasp_attempted + + +if __name__ == "__main__": + spot = Spot("RealGazeEnv") + parser = get_default_parser() + parser.add_argument("--target-object", "-t") + args = parser.parse_args() + config = construct_config(args.opts) + with spot.get_lease(hijack=True): + run_env(spot, config, target_obj_id=args.target_object) diff --git a/spot_rl_experiments/spot_rl/envs/lang_env.py b/spot_rl_experiments/spot_rl/envs/lang_env.py new file mode 100644 index 00000000..d2c77e37 --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/lang_env.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import subprocess +import time +from collections import Counter + +import magnum as mn +import numpy as np +import rospy +from hydra import compose, initialize +from spot_rl.envs.base_env import SpotBaseEnv +from spot_rl.envs.gaze_env import SpotGazeEnv +from spot_rl.llm.src.rearrange_llm import RearrangeEasyChain +from spot_rl.models.sentence_similarity import SentenceSimilarity +from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy +from spot_rl.utils.remote_spot import RemoteSpot +from spot_rl.utils.utils import ( + closest_clutter, + construct_config, + get_clutter_amounts, + get_default_parser, + get_waypoint_yaml, + nav_target_from_waypoints, + object_id_to_nav_waypoint, + place_target_from_waypoints, +) +from spot_rl.utils.whisper_translator import WhisperTranslator +from spot_wrapper.spot import Spot + +DOCK_ID = int(os.environ.get("SPOT_DOCK_ID", 520)) + + +def main(spot, use_mixer, config, out_path=None): + + if use_mixer: + policy = MixerPolicy( + config.WEIGHTS.MIXER, + config.WEIGHTS.NAV, + config.WEIGHTS.GAZE, + config.WEIGHTS.PLACE, + device=config.DEVICE, + ) + env_class = SpotMobileManipulationBaseEnv + else: + policy = SequentialExperts( + config.WEIGHTS.NAV, + config.WEIGHTS.GAZE, + config.WEIGHTS.PLACE, + device=config.DEVICE, + ) + env_class = SpotMobileManipulationSeqEnv + + env = env_class(config, spot) + + # Reset the viz params + rospy.set_param("/viz_pick", "None") + rospy.set_param("/viz_object", "None") + rospy.set_param("/viz_place", "None") + + # Check if robot should return to base + return_to_base = config.RETURN_TO_BASE + + # Get the waypoints from waypoints.yaml + waypoints = get_waypoint_yaml() + + audio_to_text = WhisperTranslator() + sentence_similarity = SentenceSimilarity() + with initialize(config_path="../llm/src/conf"): + llm_config = compose(config_name="config") + llm = RearrangeEasyChain(llm_config) + + print( + "I am ready to take instructions!\n Sample Instructions : take the rubik cube from the dining table to the hamper" + ) + print("-" * 100) + input("Are you ready?") + audio_to_text.record() + instruction = audio_to_text.translate() + print("Transcribed instructions : ", instruction) + + # Use LLM to convert user input to an instructions set + # Eg: nav_1, pick, nav_2 = 'bowl_counter', "container", 'coffee_counter' + nav_1, pick, nav_2, _ = llm.parse_instructions(instruction) + print("PARSED", nav_1, pick, nav_2) + + # Find closest nav_targets to the ones robot knows locations of + nav_1 = sentence_similarity.get_most_similar_in_list( + nav_1, list(waypoints["nav_targets"].keys()) + ) + nav_2 = sentence_similarity.get_most_similar_in_list( + nav_2, list(waypoints["nav_targets"].keys()) + ) + print("MOST SIMILAR: ", nav_1, pick, nav_2) + + # Used for Owlvit + rospy.set_param("object_target", pick) + + # Used for Visualizations + rospy.set_param("viz_pick", nav_1) + rospy.set_param("viz_object", pick) + rospy.set_param("viz_place", nav_2) + + env.power_robot() + time.sleep(1) + out_data = [] + + waypoint = nav_target_from_waypoints(nav_1) + observations = env.reset(waypoint=waypoint) + + policy.reset() + done = False + if use_mixer: + expert = None + else: + expert = Tasks.NAV + env.stopwatch.reset() + while not done: + out_data.append((time.time(), env.x, env.y, env.yaw)) + base_action, arm_action = policy.act(observations, expert=expert) + nav_silence_only = True + env.stopwatch.record("policy_inference") + observations, _, done, info = env.step( + base_action=base_action, + arm_action=arm_action, + nav_silence_only=nav_silence_only, + ) + + if use_mixer and info.get("grasp_success", False): + policy.policy.prev_nav_masks *= 0 + + if not use_mixer: + expert = info["correct_skill"] + print("Expert:", expert) + + # We reuse nav, so we have to reset it before we use it again. + if not use_mixer and expert != Tasks.NAV: + policy.nav_policy.reset() + + env.stopwatch.print_stats(latest=True) + + # Go to the dock + env.say(f"Finished object rearrangement. RETURN_TO_BASE - {return_to_base}.") + if return_to_base: + waypoint = nav_target_from_waypoints("dock") + observations = env.reset(waypoint=waypoint) + expert = Tasks.NAV + + while True: + base_action, arm_action = policy.act(observations, expert=expert) + nav_silence_only = True + env.stopwatch.record("policy_inference") + observations, _, done, info = env.step( + base_action=base_action, + arm_action=arm_action, + nav_silence_only=nav_silence_only, + ) + try: + spot.dock(dock_id=DOCK_ID, home_robot=True) + spot.home_robot() + break + except Exception: + print("Dock not found... trying again") + time.sleep(0.1) + else: + env.say("Since RETURN_TO_BASE was set to false in config.yaml, will sit down.") + time.sleep(2) + spot.sit() + + print("Done!") + + out_data.append((time.time(), env.x, env.y, env.yaw)) + + if out_path is not None: + data = ( + "\n".join([",".join([str(i) for i in t_x_y_yaw]) for t_x_y_yaw in out_data]) + + "\n" + ) + with open(out_path, "w") as f: + f.write(data) + + +class Tasks: + r"""Enumeration of types of tasks.""" + + NAV = "nav" + GAZE = "gaze" + PLACE = "place" + + +class SequentialExperts: + def __init__(self, nav_weights, gaze_weights, place_weights, device="cuda"): + print("Loading nav_policy...") + self.nav_policy = NavPolicy(nav_weights, device) + print("Loading gaze_policy...") + self.gaze_policy = GazePolicy(gaze_weights, device) + print("Loading place_policy...") + self.place_policy = PlacePolicy(place_weights, device) + print("Done loading all policies!") + + def reset(self): + self.nav_policy.reset() + self.gaze_policy.reset() + self.place_policy.reset() + + def act(self, observations, expert): + base_action, arm_action = None, None + if expert == Tasks.NAV: + base_action = self.nav_policy.act(observations) + elif expert == Tasks.GAZE: + arm_action = self.gaze_policy.act(observations) + elif expert == Tasks.PLACE: + arm_action = self.place_policy.act(observations) + + return base_action, arm_action + + +class SpotMobileManipulationBaseEnv(SpotGazeEnv): + node_name = "SpotMobileManipulationBaseEnv" + + def __init__(self, config, spot: Spot): + super().__init__(config, spot) + + # Nav + self.goal_xy = None + self.goal_heading = None + self.succ_distance = config.SUCCESS_DISTANCE + self.succ_angle = np.deg2rad(config.SUCCESS_ANGLE_DIST) + self.gaze_nav_target = None + self.place_nav_target = None + self.rho = float("inf") + self.heading_err = float("inf") + + # Gaze + self.locked_on_object_count = 0 + self.target_obj_name = config.TARGET_OBJ_NAME + + # Place + self.place_target = None + self.ee_gripper_offset = mn.Vector3(config.EE_GRIPPER_OFFSET) + self.place_target_is_local = False + + # General + self.max_episode_steps = 1000 + self.navigating_to_place = False + + def reset(self, waypoint=None, *args, **kwargs): + # Move arm to initial configuration (w/ gripper open) + self.spot.set_arm_joint_positions( + positions=np.deg2rad(self.config.GAZE_ARM_JOINT_ANGLES), travel_time=0.75 + ) + # Wait for arm to arrive to position + # import pdb; pdb.set_trace() + time.sleep(0.75) + print("open gripper called in SpotMobileManipulationBaseEnv") + self.spot.open_gripper() + + # Nav + if waypoint is None: + self.goal_xy = None + self.goal_heading = None + else: + self.goal_xy, self.goal_heading = (waypoint[:2], waypoint[2]) + + # Place + self.place_target = mn.Vector3(-1.0, -1.0, -1.0) + + # General + self.navigating_to_place = False + + return SpotBaseEnv.reset(self) + + def step(self, base_action, arm_action, *args, **kwargs): + # import pdb; pdb.set_trace() + _, xy_dist, z_dist = self.get_place_distance() + place = xy_dist < self.config.SUCC_XY_DIST and z_dist < self.config.SUCC_Z_DIST + if place: + print("place is true") + + if self.grasp_attempted: + grasp = False + else: + grasp = self.should_grasp() + + if self.grasp_attempted: + max_joint_movement_key = "MAX_JOINT_MOVEMENT_2" + else: + max_joint_movement_key = "MAX_JOINT_MOVEMENT" + + # Slow the base down if we are close to the nav target for grasp to limit blur + if ( + not self.grasp_attempted + and self.rho < 0.5 + and abs(self.heading_err) < np.rad2deg(45) + ): + self.slowdown_base = 0.5 # Hz + print("!!!!!!Slow mode!!!!!!") + else: + self.slowdown_base = -1 + disable_oa = False if self.rho > 0.3 and self.config.USE_OA_FOR_NAV else None + observations, reward, done, info = SpotBaseEnv.step( + self, + base_action=base_action, + arm_action=arm_action, + grasp=grasp, + place=place, + max_joint_movement_key=max_joint_movement_key, + disable_oa=disable_oa, + *args, + **kwargs, + ) + if done: + print("done is true") + + if self.grasp_attempted and not self.navigating_to_place: + # Determine where to go based on what object we've just grasped + waypoint_name = rospy.get_param("/viz_place") + waypoint = nav_target_from_waypoints(waypoint_name) + + self.say("Navigating to " + waypoint_name) + self.place_target = place_target_from_waypoints(waypoint_name) + self.goal_xy, self.goal_heading = (waypoint[:2], waypoint[2]) + self.navigating_to_place = True + info["grasp_success"] = True + + return observations, reward, done, info + + def get_observations(self): + observations = self.get_nav_observation(self.goal_xy, self.goal_heading) + rho = observations["target_point_goal_gps_and_compass_sensor"][0] + self.rho = rho + goal_heading = observations["goal_heading"][0] + self.heading_err = goal_heading + self.use_mrcnn = True + observations.update(super().get_observations()) + observations["obj_start_sensor"] = self.get_place_sensor() + + return observations + + def get_success(self, observations): + return self.place_attempted + + +class SpotMobileManipulationSeqEnv(SpotMobileManipulationBaseEnv): + node_name = "SpotMobileManipulationSeqEnv" + + def __init__(self, config, spot: Spot): + super().__init__(config, spot) + self.current_task = Tasks.NAV + self.timeout_start = float("inf") + + def reset(self, *args, **kwargs): + observations = super().reset(*args, **kwargs) + self.current_task = Tasks.NAV + self.target_obj_name = 0 + self.timeout_start = float("inf") + + return observations + + def step(self, *args, **kwargs): + pre_step_navigating_to_place = self.navigating_to_place + observations, reward, done, info = super().step(*args, **kwargs) + + if self.current_task != Tasks.GAZE: + # Disable target searching if we are not gazing + self.last_seen_objs = [] + + if self.current_task == Tasks.NAV and self.get_nav_success( + observations, self.succ_distance, self.succ_angle + ): + if not self.grasp_attempted: + self.current_task = Tasks.GAZE + self.timeout_start = time.time() + self.target_obj_name = None + else: + self.current_task = Tasks.PLACE + self.say("Starting place") + self.timeout_start = time.time() + + if self.current_task == Tasks.PLACE and time.time() > self.timeout_start + 10: + # call place after 10s of trying + print("Place failed to reach target") + self.spot.rotate_gripper_with_delta(wrist_roll=1.57) + spot.open_gripper() + time.sleep(0.75) + done = True + + if not pre_step_navigating_to_place and self.navigating_to_place: + # This means that the Gaze task has just ended + self.current_task = Tasks.NAV + + info["correct_skill"] = self.current_task + + self.use_mrcnn = self.current_task == Tasks.GAZE + + # + + return observations, reward, done, info + + +if __name__ == "__main__": + parser = get_default_parser() + parser.add_argument("-m", "--use-mixer", action="store_true") + parser.add_argument("--output") + args = parser.parse_args() + config = construct_config(args.opts) + spot = (RemoteSpot if config.USE_REMOTE_SPOT else Spot)("RealSeqEnv") + if config.USE_REMOTE_SPOT: + try: + main(spot, args.use_mixer, config, args.output) + finally: + spot.power_off() + else: + with spot.get_lease(hijack=True): + try: + main(spot, args.use_mixer, config, args.output) + finally: + spot.power_off() diff --git a/spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py b/spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py new file mode 100644 index 00000000..d3685bf6 --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import time +from collections import Counter + +import magnum as mn +import numpy as np +import rospy +from spot_rl.envs.base_env import SpotBaseEnv +from spot_rl.envs.gaze_env import SpotGazeEnv +from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy +from spot_rl.utils.remote_spot import RemoteSpot +from spot_rl.utils.utils import ( + closest_clutter, + construct_config, + get_clutter_amounts, + get_default_parser, + get_waypoint_yaml, + nav_target_from_waypoints, + object_id_to_nav_waypoint, + place_target_from_waypoints, +) +from spot_wrapper.spot import Spot + +CLUTTER_AMOUNTS = Counter() # type: Counter +CLUTTER_AMOUNTS.update(get_clutter_amounts()) +NUM_OBJECTS = np.sum(list(CLUTTER_AMOUNTS.values())) +DOCK_ID = int(os.environ.get("SPOT_DOCK_ID", 520)) + +DEBUGGING = False + + +def main(spot, use_mixer, config, out_path=None): + if use_mixer: + policy = MixerPolicy( + config.WEIGHTS.MIXER, + config.WEIGHTS.NAV, + config.WEIGHTS.GAZE, + config.WEIGHTS.PLACE, + device=config.DEVICE, + ) + env_class = SpotMobileManipulationBaseEnv + else: + policy = SequentialExperts( + config.WEIGHTS.NAV, + config.WEIGHTS.GAZE, + config.WEIGHTS.PLACE, + device=config.DEVICE, + ) + env_class = SpotMobileManipulationSeqEnv + + env = env_class(config, spot) + + # Reset the viz params + rospy.set_param("/viz_pick", "None") + rospy.set_param("/viz_object", "None") + rospy.set_param("/viz_place", "None") + + # Check if robot should return to base + return_to_base = config.RETURN_TO_BASE + + # Get the waypoints from waypoints.yaml + waypoints = get_waypoint_yaml() + + objects_to_look = [] + for waypoint in waypoints["object_targets"]: + objects_to_look.append(waypoints["object_targets"][waypoint][0]) + rospy.set_param("object_target", ",".join(objects_to_look)) + + env.power_robot() + time.sleep(1) + count = Counter() + out_data = [] + + for trip_idx in range(NUM_OBJECTS + 1): + if trip_idx < NUM_OBJECTS: + # 2 objects per receptacle + clutter_blacklist = [ + i for i in waypoints["clutter"] if count[i] >= CLUTTER_AMOUNTS[i] + ] + waypoint_name, waypoint = closest_clutter( + env.x, env.y, clutter_blacklist=clutter_blacklist + ) + count[waypoint_name] += 1 + env.say("Going to " + waypoint_name + " to search for objects") + rospy.set_param("viz_pick", waypoint_name) + rospy.set_param("viz_object", ",".join(objects_to_look)) + rospy.set_param("viz_place", "None") + else: + env.say( + f"Finished object rearrangement. RETURN_TO_BASE - {return_to_base}." + ) + if return_to_base: + waypoint = nav_target_from_waypoints("dock") + else: + waypoint = None + break + observations = env.reset(waypoint=waypoint) + policy.reset() + done = False + if use_mixer: + expert = None + else: + expert = Tasks.NAV + env.stopwatch.reset() + while not done: + out_data.append((time.time(), env.x, env.y, env.yaw)) + + if use_mixer: + base_action, arm_action = policy.act(observations) + nav_silence_only = policy.nav_silence_only + else: + base_action, arm_action = policy.act(observations, expert=expert) + nav_silence_only = True + env.stopwatch.record("policy_inference") + observations, _, done, info = env.step( + base_action=base_action, + arm_action=arm_action, + nav_silence_only=nav_silence_only, + ) + # if done: + # import pdb; pdb.set_trace() + + if use_mixer and info.get("grasp_success", False): + policy.policy.prev_nav_masks *= 0 + + if not use_mixer: + expert = info["correct_skill"] + + # Check if the robot has arrived back at the dock + if trip_idx >= NUM_OBJECTS and env.get_nav_success( + observations, config.SUCCESS_DISTANCE, np.deg2rad(10) + ): + # The robot has arrived back at the dock + break + + # Print info + # stats = [f"{k}: {v}" for k, v in info.items()] + # print(" ".join(stats)) + + # We reuse nav, so we have to reset it before we use it again. + if not use_mixer and expert != Tasks.NAV: + policy.nav_policy.reset() + + env.stopwatch.print_stats(latest=True) + + # Ensure gripper is open (place may have timed out) + # if not env.place_attempted: + # print("open gripper in place attempted") + # env.spot.open_gripper() + # time.sleep(2) + + out_data.append((time.time(), env.x, env.y, env.yaw)) + + if out_path is not None: + data = ( + "\n".join([",".join([str(i) for i in t_x_y_yaw]) for t_x_y_yaw in out_data]) + + "\n" + ) + with open(out_path, "w") as f: + f.write(data) + + if return_to_base: + env.say("Executing automatic docking") + dock_start_time = time.time() + while time.time() - dock_start_time < 2: + try: + spot.dock(dock_id=DOCK_ID, home_robot=True) + except Exception: + print("Dock not found... trying again") + time.sleep(0.1) + else: + env.say("Since RETURN_TO_BASE was set to false in config.yaml, will sit down.") + time.sleep(2) + spot.sit() + + print("Done!") + + +class Tasks: + r"""Enumeration of types of tasks.""" + + NAV = "nav" + GAZE = "gaze" + PLACE = "place" + + +class SequentialExperts: + def __init__(self, nav_weights, gaze_weights, place_weights, device="cuda"): + print("Loading nav_policy...") + self.nav_policy = NavPolicy(nav_weights, device) + print("Loading gaze_policy...") + self.gaze_policy = GazePolicy(gaze_weights, device) + print("Loading place_policy...") + self.place_policy = PlacePolicy(place_weights, device) + print("Done loading all policies!") + + def reset(self): + self.nav_policy.reset() + self.gaze_policy.reset() + self.place_policy.reset() + + def act(self, observations, expert): + base_action, arm_action = None, None + if expert == Tasks.NAV: + base_action = self.nav_policy.act(observations) + elif expert == Tasks.GAZE: + arm_action = self.gaze_policy.act(observations) + elif expert == Tasks.PLACE: + arm_action = self.place_policy.act(observations) + + return base_action, arm_action + + +class SpotMobileManipulationBaseEnv(SpotGazeEnv): + node_name = "SpotMobileManipulationBaseEnv" + + def __init__(self, config, spot: Spot): + super().__init__(config, spot) + + # Nav + self.goal_xy = None + self.goal_heading = None + self.succ_distance = config.SUCCESS_DISTANCE + self.succ_angle = np.deg2rad(config.SUCCESS_ANGLE_DIST) + self.gaze_nav_target = None + self.place_nav_target = None + self.rho = float("inf") + self.heading_err = float("inf") + + # Gaze + self.locked_on_object_count = 0 + self.target_obj_name = config.TARGET_OBJ_NAME + + # Place + self.place_target = None + self.ee_gripper_offset = mn.Vector3(config.EE_GRIPPER_OFFSET) + self.place_target_is_local = False + + # General + self.max_episode_steps = 1000 + self.navigating_to_place = False + + def reset(self, waypoint=None, *args, **kwargs): + # Move arm to initial configuration (w/ gripper open) + self.spot.set_arm_joint_positions( + positions=np.deg2rad(self.config.GAZE_ARM_JOINT_ANGLES), travel_time=0.75 + ) + # Wait for arm to arrive to position + # import pdb; pdb.set_trace() + time.sleep(0.75) + print("open gripper called in SpotMobileManipulationBaseEnv") + self.spot.open_gripper() + + # Nav + if waypoint is None: + self.goal_xy = None + self.goal_heading = None + else: + self.goal_xy, self.goal_heading = (waypoint[:2], waypoint[2]) + + # Place + self.place_target = mn.Vector3(-1.0, -1.0, -1.0) + + # General + self.navigating_to_place = False + + return SpotBaseEnv.reset(self) + + def step(self, base_action, arm_action, *args, **kwargs): + # import pdb; pdb.set_trace() + _, xy_dist, z_dist = self.get_place_distance() + place = xy_dist < self.config.SUCC_XY_DIST and z_dist < self.config.SUCC_Z_DIST + if place: + print("place is true") + + if self.grasp_attempted: + grasp = False + else: + grasp = self.should_grasp() + + if self.grasp_attempted: + max_joint_movement_key = "MAX_JOINT_MOVEMENT_2" + else: + max_joint_movement_key = "MAX_JOINT_MOVEMENT" + + # Slow the base down if we are close to the nav target for grasp to limit blur + if ( + not self.grasp_attempted + and self.rho < 0.5 + and abs(self.heading_err) < np.rad2deg(45) + ): + self.slowdown_base = 0.5 # Hz + print("!!!!!!Slow mode!!!!!!") + else: + self.slowdown_base = -1 + disable_oa = False if self.rho > 0.3 and self.config.USE_OA_FOR_NAV else None + observations, reward, done, info = SpotBaseEnv.step( + self, + base_action=base_action, + arm_action=arm_action, + grasp=grasp, + place=place, + max_joint_movement_key=max_joint_movement_key, + disable_oa=disable_oa, + *args, + **kwargs, + ) + if done: + print("done is true") + + if self.grasp_attempted and not self.navigating_to_place: + # Determine where to go based on what object we've just grasped + waypoint_name, waypoint = object_id_to_nav_waypoint(self.target_obj_name) + self.say("Navigating to " + waypoint_name) + rospy.set_param("viz_object", self.target_obj_name) + rospy.set_param("viz_place", waypoint_name) + self.place_target = place_target_from_waypoints(waypoint_name) + self.goal_xy, self.goal_heading = (waypoint[:2], waypoint[2]) + self.navigating_to_place = True + info["grasp_success"] = True + + return observations, reward, done, info + + def get_observations(self): + observations = self.get_nav_observation(self.goal_xy, self.goal_heading) + rho = observations["target_point_goal_gps_and_compass_sensor"][0] + self.rho = rho + goal_heading = observations["goal_heading"][0] + self.heading_err = goal_heading + self.use_mrcnn = True + observations.update(super().get_observations()) + observations["obj_start_sensor"] = self.get_place_sensor() + + return observations + + def get_success(self, observations): + return self.place_attempted + + +class SpotMobileManipulationSeqEnv(SpotMobileManipulationBaseEnv): + node_name = "SpotMobileManipulationSeqEnv" + + def __init__(self, config, spot: Spot): + super().__init__(config, spot) + self.current_task = Tasks.NAV + self.timeout_start = float("inf") + + def reset(self, *args, **kwargs): + observations = super().reset(*args, **kwargs) + self.current_task = Tasks.NAV + self.target_obj_name = 0 + self.timeout_start = float("inf") + + return observations + + def step(self, *args, **kwargs): + pre_step_navigating_to_place = self.navigating_to_place + observations, reward, done, info = super().step(*args, **kwargs) + + if self.current_task != Tasks.GAZE: + # Disable target searching if we are not gazing + self.last_seen_objs = [] + + if self.current_task == Tasks.NAV and self.get_nav_success( + observations, self.succ_distance, self.succ_angle + ): + if not self.grasp_attempted: + self.current_task = Tasks.GAZE + self.timeout_start = time.time() + self.target_obj_name = None + else: + self.current_task = Tasks.PLACE + self.say("Starting place") + self.timeout_start = time.time() + + if self.current_task == Tasks.PLACE and time.time() > self.timeout_start + 10: + # call place after 10s of trying + print("Place failed to reach target") + self.spot.rotate_gripper_with_delta(wrist_roll=1.57) + spot.open_gripper() + time.sleep(0.75) + done = True + + if not pre_step_navigating_to_place and self.navigating_to_place: + # This means that the Gaze task has just ended + self.current_task = Tasks.NAV + + info["correct_skill"] = self.current_task + + self.use_mrcnn = self.current_task == Tasks.GAZE + + # + + return observations, reward, done, info + + +if __name__ == "__main__": + parser = get_default_parser() + parser.add_argument("-m", "--use-mixer", action="store_true") + parser.add_argument("--output") + args = parser.parse_args() + config = construct_config(args.opts) + spot = (RemoteSpot if config.USE_REMOTE_SPOT else Spot)("RealSeqEnv") + if config.USE_REMOTE_SPOT: + try: + main(spot, args.use_mixer, config, args.output) + finally: + spot.power_off() + else: + with spot.get_lease(hijack=True): + try: + main(spot, args.use_mixer, config, args.output) + finally: + spot.power_off() diff --git a/spot_rl_experiments/spot_rl/envs/nav_env.py b/spot_rl_experiments/spot_rl/envs/nav_env.py new file mode 100644 index 00000000..b082c950 --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/nav_env.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import time + +import numpy as np +from spot_rl.envs.base_env import SpotBaseEnv +from spot_rl.real_policy import NavPolicy +from spot_rl.utils.utils import ( + construct_config, + get_default_parser, + nav_target_from_waypoints, +) +from spot_wrapper.spot import Spot + +DOCK_ID = int(os.environ.get("SPOT_DOCK_ID", 520)) + + +def main(spot): + parser = get_default_parser() + parser.add_argument("-g", "--goal") + parser.add_argument("-w", "--waypoint") + parser.add_argument("-d", "--dock", action="store_true") + args = parser.parse_args() + config = construct_config(args.opts) + + # Don't need gripper camera for Nav + config.USE_MRCNN = False + + policy = NavPolicy(config.WEIGHTS.NAV, device=config.DEVICE) + policy.reset() + + env = SpotNavEnv(config, spot) + env.power_robot() + if args.waypoint is not None: + goal_x, goal_y, goal_heading = nav_target_from_waypoints(args.waypoint) + env.say(f"Navigating to {args.waypoint}") + else: + assert args.goal is not None + goal_x, goal_y, goal_heading = [float(i) for i in args.goal.split(",")] + observations = env.reset((goal_x, goal_y), goal_heading) + done = False + time.sleep(1) + try: + while not done: + action = policy.act(observations) + observations, _, done, _ = env.step(base_action=action) + if args.dock: + env.say("Executing automatic docking") + dock_start_time = time.time() + while time.time() - dock_start_time < 2: + try: + spot.dock(dock_id=DOCK_ID, home_robot=True) + except Exception: + print("Dock not found... trying again") + time.sleep(0.1) + finally: + spot.power_off() + + +class SpotNavEnv(SpotBaseEnv): + def __init__(self, config, spot: Spot): + super().__init__(config, spot) + self.goal_xy = None + self.goal_heading = None + self.succ_distance = config.SUCCESS_DISTANCE + self.succ_angle = np.deg2rad(config.SUCCESS_ANGLE_DIST) + + def reset(self, goal_xy, goal_heading): + self.goal_xy = np.array(goal_xy, dtype=np.float32) + self.goal_heading = goal_heading + observations = super().reset() + assert len(self.goal_xy) == 2 + + return observations + + def get_success(self, observations): + succ = self.get_nav_success(observations, self.succ_distance, self.succ_angle) + if succ: + self.spot.set_base_velocity(0.0, 0.0, 0.0, 1 / self.ctrl_hz) + return succ + + def get_observations(self): + return self.get_nav_observation(self.goal_xy, self.goal_heading) + + +if __name__ == "__main__": + spot = Spot("RealNavEnv") + with spot.get_lease(hijack=True): + main(spot) diff --git a/spot_rl_experiments/spot_rl/envs/place_env.py b/spot_rl_experiments/spot_rl/envs/place_env.py new file mode 100644 index 00000000..f67eae9d --- /dev/null +++ b/spot_rl_experiments/spot_rl/envs/place_env.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import magnum as mn +import numpy as np +from spot_rl.envs.base_env import SpotBaseEnv +from spot_rl.real_policy import PlacePolicy +from spot_rl.utils.utils import ( + construct_config, + get_default_parser, + place_target_from_waypoints, +) +from spot_wrapper.spot import Spot + + +def main(spot): + parser = get_default_parser() + parser.add_argument("-p", "--place_target") + parser.add_argument("-w", "--waypoint") + parser.add_argument("-l", "--target_is_local", action="store_true") + args = parser.parse_args() + config = construct_config(args.opts) + + # Don't need cameras for Place + config.USE_HEAD_CAMERA = False + config.USE_MRCNN = False + + if args.waypoint is not None: + assert not args.target_is_local + place_target = place_target_from_waypoints(args.waypoint) + else: + assert args.place_target is not None + place_target = [float(i) for i in args.place_target.split(",")] + env = SpotPlaceEnv(config, spot, place_target, args.target_is_local) + env.power_robot() + policy = PlacePolicy(config.WEIGHTS.PLACE, device=config.DEVICE) + policy.reset() + observations = env.reset() + done = False + env.say("Starting episode") + while not done: + action = policy.act(observations) + observations, _, done, _ = env.step(arm_action=action) + if done: + while True: + env.reset() + spot.set_base_velocity(0, 0, 0, 1.0) + + +class SpotPlaceEnv(SpotBaseEnv): + def __init__(self, config, spot: Spot, place_target, target_is_local=False): + super().__init__(config, spot) + self.place_target = np.array(place_target) + self.place_target_is_local = target_is_local + self.ee_gripper_offset = mn.Vector3(config.EE_GRIPPER_OFFSET) + self.placed = False + + def reset(self, *args, **kwargs): + # Move arm to initial configuration + cmd_id = self.spot.set_arm_joint_positions( + positions=self.initial_arm_joint_angles, travel_time=0.75 + ) + self.spot.block_until_arm_arrives(cmd_id, timeout_sec=2) + + observations = super(SpotPlaceEnv, self).reset() + self.placed = False + return observations + + def step(self, place=False, *args, **kwargs): + _, xy_dist, z_dist = self.get_place_distance() + place = xy_dist < self.config.SUCC_XY_DIST and z_dist < self.config.SUCC_Z_DIST + return super().step(place=place, *args, **kwargs) + + def get_success(self, observations): + return self.place_attempted + + def get_observations(self): + observations = { + "joint": self.get_arm_joints(), + "obj_start_sensor": self.get_place_sensor(), + } + + return observations + + +if __name__ == "__main__": + spot = Spot("RealPlaceEnv") + with spot.get_lease(hijack=True): + main(spot) diff --git a/spot_rl_experiments/spot_rl/launch/core.sh b/spot_rl_experiments/spot_rl/launch/core.sh new file mode 100644 index 00000000..70a992ff --- /dev/null +++ b/spot_rl_experiments/spot_rl/launch/core.sh @@ -0,0 +1,19 @@ +echo "Killing all tmux sessions..." +tmux kill-session -t roscore +tmux kill-session -t headless_estop +tmux kill-session -t img_pub +tmux kill-session -t propio_pub +tmux kill-session -t tts_sub +tmux kill-session -t remote_spot_listener +sleep 2 +echo "Starting roscore tmux..." +tmux new -s roscore -d '$CONDA_PREFIX/bin/roscore' +sleep 1 +echo "Starting other tmux nodes" +tmux new -s headless_estop -d '$CONDA_PREFIX/bin/python -m spot_wrapper.headless_estop' +tmux new -s img_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --core' +tmux new -s propio_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.helper_nodes --proprioception' +tmux new -s tts_sub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.helper_nodes --text-to-speech' +tmux new -s remote_spot_listener -d 'while true; do $CONDA_PREFIX/bin/python -m spot_rl.utils.remote_spot_listener ; done' +sleep 3 +tmux ls diff --git a/spot_rl_experiments/spot_rl/launch/kill_sessions.sh b/spot_rl_experiments/spot_rl/launch/kill_sessions.sh new file mode 100644 index 00000000..1542b1e3 --- /dev/null +++ b/spot_rl_experiments/spot_rl/launch/kill_sessions.sh @@ -0,0 +1,9 @@ +echo "Killing all tmux sessions..." +tmux kill-session -t roscore +tmux kill-session -t headless_estop +tmux kill-session -t img_pub +tmux kill-session -t propio_pub +tmux kill-session -t tts_sub +tmux kill-session -t remote_spot_listener +echo "Here are your remaining tmux sessions:" +tmux ls diff --git a/spot_rl_experiments/spot_rl/launch/local_listener.sh b/spot_rl_experiments/spot_rl/launch/local_listener.sh new file mode 100644 index 00000000..38a27943 --- /dev/null +++ b/spot_rl_experiments/spot_rl/launch/local_listener.sh @@ -0,0 +1,7 @@ +echo "Killing img_pub sessions..." +tmux kill-session -t img_pub +sleep 1 +echo "Starting img_pub session" +tmux new -s img_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --listen' +sleep 3 +tmux ls diff --git a/spot_rl_experiments/spot_rl/launch/local_only.sh b/spot_rl_experiments/spot_rl/launch/local_only.sh new file mode 100644 index 00000000..3be27819 --- /dev/null +++ b/spot_rl_experiments/spot_rl/launch/local_only.sh @@ -0,0 +1,17 @@ +echo "Killing all tmux sessions..." +tmux kill-session -t roscore +tmux kill-session -t img_pub +tmux kill-session -t propio_pub +tmux kill-session -t tts_sub +sleep 1 +echo "Starting roscore tmux..." +tmux new -s roscore -d '$CONDA_PREFIX/bin/roscore' +echo "Starting other tmux nodes.." +tmux new -s img_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --local' +tmux new -s propio_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.helper_nodes --proprioception' +tmux new -s tts_sub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.helper_nodes --text-to-speech' +sleep 3 +tmux ls + +# This for running mask rcnn in img_publishers, which needs input images to be in grayscale +#tmux new -s img_pub -d '$CONDA_PREFIX/bin/python -m spot_rl.utils.img_publishers --local --bounding_box_detector mrcnn' diff --git a/spot_rl_experiments/spot_rl/llm/.gitignore b/spot_rl_experiments/spot_rl/llm/.gitignore new file mode 100644 index 00000000..53adab2e --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/.gitignore @@ -0,0 +1,130 @@ +# Byte-compiled / optimized / DLL files +**/outputs/** +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/spot_rl_experiments/spot_rl/llm/README.md b/spot_rl_experiments/spot_rl/llm/README.md new file mode 100644 index 00000000..563bcd0b --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/README.md @@ -0,0 +1,36 @@ +## Setup + +``` +pip install openai +pip install hydra-core --upgrade +``` + +Create an environmental variable `OPENAI_API_KEY` with your api key, +``` +os.environ["OPENAI_API_KEY"] = "...' +or export OPENAI_API_KEY='...' +``` + +## Usage + +``` +python main.py +instruction='Take the water from the table to the kitchen counter' verbose=false +``` +Or check `src/notebook.ipynb` + +## Config + +``` +├── conf +│ ├── config.yaml +│ ├── llm +│ │ └── openai.yaml +│ └── prompt +│ ├── rearrange_easy_few_shot.yaml +│ └── rearrange_easy_zero_shot.yaml +``` + +- llm/openai.yaml: contains the openai configuration: engine, tokens, temperature, etc. Can +modify it by running `python main.py llm.temperature=0.5` +- prompt/.: contains the prompts for the zero shot and few shot [defaults] tasks + diff --git a/spot_rl_experiments/spot_rl/llm/src/conf/config.yaml b/spot_rl_experiments/spot_rl/llm/src/conf/config.yaml new file mode 100644 index 00000000..f9e20fed --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/conf/config.yaml @@ -0,0 +1,5 @@ +defaults: + - llm: openai + - prompt: rearrange_easy_few_shot + +verbose: false \ No newline at end of file diff --git a/spot_rl_experiments/spot_rl/llm/src/conf/llm/openai.yaml b/spot_rl_experiments/spot_rl/llm/src/conf/llm/openai.yaml new file mode 100644 index 00000000..92db2579 --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/conf/llm/openai.yaml @@ -0,0 +1,41 @@ +# text-davinci-003, text-curie-001, text-babbage-001, text-ada-001 +engine: text-davinci-003 + +# The prompt to start the generation from. +prompt: '' + +# The maximum number of tokens to generate in the completion. +max_tokens: 100 + +# Sampling temperature between 0 and 2. Higher values will make the output more random, +# while lower values like 0.2 will make it more focused and deterministic. +temperature: .75 + +# An alternative to temperature, nucleus sampling. The model considers the results +# of the toklen with top_p probability mass. So 0.1 means only the tokens comprising +# the top 10% probability mass are considered. +top_p: 1 + +# Returns the best `n` out of `best_of` completions made on server side +n: 1 +best_of: 3 + +# Whether to stream back partial progress +stream: False + +# Include log-probabilities of the top `logprobs` tokens. +logprobs: 0 + +# up to 4 sequences that stop the generation +stop: '' + +# Dictionary that can modify the likelihood of specified tokens appearing in the completion. +# logit_bias: {} + +# Other params +frequency_penalty: 0 +presence_penalty: 0 +request_timeout: 20 + + + diff --git a/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_few_shot.yaml b/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_few_shot.yaml new file mode 100644 index 00000000..45a71aed --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_few_shot.yaml @@ -0,0 +1,16 @@ +main_prompt: |- + You will solve a simple rearrangement task that requires you to Navigate to a given object + and Pick it up, and then Navigate to a given location and Place it there. Given an open instruction + that could be similar to "Go to the table and find the mug, and return the mug to box", you need to + return the solution sequence of actions: Nav(table), Pick(mug), Nav(box), Place(mug, box). +examples: |- + EXAMPLES: + Instruction: Go to table and find the mug, and return the mug to box + Solution: Nav(table), Pick(mug), Nav(box), Place(mug, box) + Instruction: Bring the apple from the kitchen counter to the table + Solution: Nav(kitchen counter), Pick(apple), Nav(table), Place(apple, table) +suffix: |- + Let's go! + Instruction: + Solution: +input_variable: rearrange_instruction \ No newline at end of file diff --git a/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_zero_shot.yaml b/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_zero_shot.yaml new file mode 100644 index 00000000..ded7f202 --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/conf/prompt/rearrange_easy_zero_shot.yaml @@ -0,0 +1,10 @@ +main_prompt: |- + You will solve a simple rearrangement task that requires you to Navigate to a given object + and Pick it up, and then Navigate to a given location and Place it there. Given an open instruction + that could be similar to "Go to the table and find the mug, and return the mug to box", you need to + return the solution sequence of actions: Nav(Table), Pick(Mug), Nav(Box), Place(Mug, Box). +suffix: |- + Let's go! + Instruction: + Solution: +input_variable: rearrange_instruction \ No newline at end of file diff --git a/spot_rl_experiments/spot_rl/llm/src/notebook.ipynb b/spot_rl_experiments/spot_rl/llm/src/notebook.ipynb new file mode 100644 index 00000000..730710f4 --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/notebook.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from main import RearrangeEasyChain\n", + "from hydra import compose, initialize\n", + "with initialize(version_base='1.1', config_path='conf'):\n", + " conf = compose(config_name=\"config\") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain = RearrangeEasyChain(conf)\n", + "chain.generate('Take a fruit to the kitchen counter')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "robot-llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "1c8265bd49bb56c0bee7cdaedd83419fbfa4c22e8925da46233a4da30e3ac686" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/spot_rl_experiments/spot_rl/llm/src/rearrange_llm.py b/spot_rl_experiments/spot_rl/llm/src/rearrange_llm.py new file mode 100644 index 00000000..d8260be8 --- /dev/null +++ b/spot_rl_experiments/spot_rl/llm/src/rearrange_llm.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +import os + +import hydra +import openai +import regex as re +from omegaconf import DictConfig + + +class OpenAI: + def __init__(self, conf): + self.llm_conf = conf.llm + self.client = openai.Completion() + self._validate_conf() + self.verbose = conf.verbose + + def _validate_conf(self): + try: + openai.api_key = os.environ["OPENAI_API_KEY"] + except Exception: + raise ValueError("No API keys provided") + if self.llm_conf.stream: + raise ValueError("Streaming not supported") + if self.llm_conf.n > 1 and self.llm_conf.stream: + raise ValueError("Cannot stream results with n > 1") + if self.llm_conf.best_of > 1 and self.llm_conf.stream: + raise ValueError("Cannot stream results with best_of > 1") + + def generate(self, prompt): + params = copy.deepcopy(self.llm_conf) + params["prompt"] = prompt + if self.verbose: + print(f"Prompt: {prompt}") + return self.client.create(**params) + + +class RearrangeEasyChain: + def __init__(self, conf): + self.conf = conf + self._build_prompt() + self.llm = OpenAI(conf) + self.input_variable = f"<{self.conf.prompt.input_variable}>" + + def _build_prompt(self): + self.prompt = self.conf.prompt.main_prompt + if "examples" in self.conf.prompt: + self.prompt += f"\n{self.conf.prompt.examples}" + if "suffix" in self.conf.prompt: + self.prompt += f"\n{self.conf.prompt.suffix}" + + def generate(self, input): + prompt = self.prompt.replace(self.input_variable, input) + ans = self.llm.generate(prompt) + return ans + + def parse_instructions(self, input): + text = self.generate(input)["choices"][0]["text"] + matches = re.findall("\(.*?\)", text) # noqa + matches = [match.replace("(", "").replace(")", "") for match in matches] + nav_1, pick, nav_2, place = matches + place, nav_2 = place.split(",") + nav_1 = nav_1.strip() + pick = pick.strip() + nav_2 = nav_2.strip() + place = place.strip() + return nav_1, pick, nav_2, place + + +@hydra.main(config_name="config", config_path="conf") +def main(conf: DictConfig): + chain = RearrangeEasyChain(conf) + instruction = conf.instruction + ans = chain.generate(instruction) + print(ans["choices"][0]["text"]) + + +if __name__ == "__main__": + main() diff --git a/spot_rl_experiments/spot_rl/models/__init__.py b/spot_rl_experiments/spot_rl/models/__init__.py new file mode 100644 index 00000000..5221562a --- /dev/null +++ b/spot_rl_experiments/spot_rl/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from spot_rl.models.owlvit import OwlVit +from spot_rl.models.sentence_similarity import SentenceSimilarity diff --git a/spot_rl_experiments/spot_rl/models/owlvit.py b/spot_rl_experiments/spot_rl/models/owlvit.py new file mode 100644 index 00000000..6643565a --- /dev/null +++ b/spot_rl_experiments/spot_rl/models/owlvit.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import argparse +import time + +import cv2 +import torch +from PIL import Image +from transformers import OwlViTForObjectDetection, OwlViTProcessor + + +class OwlVit: + def __init__(self, labels, score_threshold, show_img): + # self.device = torch.device('cpu') + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + + self.model = OwlViTForObjectDetection.from_pretrained( + "google/owlvit-base-patch32" + ) + self.model.eval() + self.model.to(self.device) + + self.processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + + self.labels = labels + self.score_threshold = score_threshold + self.show_img = show_img + + def run_inference(self, img): + """ + img: an open cv image in (H, W, C) format + """ + # Process inputs + # img = img.to(self.device) + inputs = self.processor(text=self.labels, images=img, return_tensors="pt") + + # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + # target_sizes = torch.Tensor([img.size[::-1]]) this is for PIL images + target_sizes = torch.Tensor([img.shape[:2]]).to(self.device) + inputs = inputs.to(self.device) + + # Inference + with torch.no_grad(): + outputs = self.model(**inputs) + + # Convert outputs (bounding boxes and class logits) to COCO API + results = self.processor.post_process( + outputs=outputs, target_sizes=target_sizes + ) + # img = img.to('cpu') + + if self.show_img: + self.show_img_with_overlaid_bounding_boxes(img, results) + + return self.get_most_confident_bounding_box_per_label(results) + + def run_inference_and_return_img(self, img): + """ + img: an open cv image in (H, W, C) format + """ + # img = img.to(self.device) + + inputs = self.processor(text=self.labels, images=img, return_tensors="pt") + target_sizes = torch.Tensor([img.shape[:2]]).to(self.device) + inputs = inputs.to(self.device) + # Inference + with torch.no_grad(): + outputs = self.model(**inputs) + + # Convert outputs (bounding boxes and class logits) to COCO API + results = self.processor.post_process( + outputs=outputs, target_sizes=target_sizes + ) + # img = img.to('cpu') + # if self.show_img: + # self.show_img_with_overlaid_bounding_boxes(img, results) + + return self.get_most_confident_bounding_box_per_label( + results + ), self.create_img_with_bounding_box(img, results) + + def show_img_with_overlaid_bounding_boxes(self, img, results): + img = self.create_img_with_bounding_box(img, results) + cv2.imshow("img", img) + cv2.waitKey(1) + + def get_bounding_boxes(self, results): + """ + Returns all bounding boxes with a score above the threshold + """ + boxes, scores, labels = ( + results[0]["boxes"], + results[0]["scores"], + results[0]["labels"], + ) + boxes = boxes.to("cpu") + labels = labels.to("cpu") + scores = scores.to("cpu") + + target_boxes = [] + for box, score, label in zip(boxes, scores, labels): + box = [round(i, 2) for i in box.tolist()] + if score >= self.score_threshold: + target_boxes.append([self.labels[0][label.item()], score.item(), box]) + + return target_boxes + + def get_most_confident_bounding_box(self, results): + """ + Returns the most confident bounding box + """ + boxes, scores, labels = ( + results[0]["boxes"], + results[0]["scores"], + results[0]["labels"], + ) + boxes = boxes.to("cpu") + labels = labels.to("cpu") + scores = scores.to("cpu") + + target_box = [] + target_score = -float("inf") + + for box, score, label in zip(boxes, scores, labels): + box = [round(i, 2) for i in box.tolist()] + if score >= self.score_threshold: + if score > target_score: + target_score = score + target_box = box + + if target_score == -float("inf"): + return None + else: + x1 = int(target_box[0]) + y1 = int(target_box[1]) + x2 = int(target_box[2]) + y2 = int(target_box[3]) + + print("location:", x1, y1, x2, y2) + return x1, y1, x2, y2 + + def get_most_confident_bounding_box_per_label(self, results): + """ + Returns the most confident bounding box for each label above the threshold + """ + boxes, scores, labels = ( + results[0]["boxes"], + results[0]["scores"], + results[0]["labels"], + ) + boxes = boxes.to("cpu") + labels = labels.to("cpu") + scores = scores.to("cpu") + + # Initialize dictionaries to store most confident bounding boxes and scores per label + target_boxes = {} + target_scores = {} + + for box, score, label in zip(boxes, scores, labels): + box = [round(i, 2) for i in box.tolist()] + if score >= self.score_threshold: + # If the current score is higher than the stored score for this label, update the target box and score + if ( + label.item() not in target_scores + or score > target_scores[label.item()] + ): + target_scores[label.item()] = score.item() + target_boxes[label.item()] = box + + # Format the output + result = [] + for label, box in target_boxes.items(): + x1 = int(box[0]) + y1 = int(box[1]) + x2 = int(box[2]) + y2 = int(box[3]) + + result.append( + [self.labels[0][label], target_scores[label], [x1, y1, x2, y2]] + ) + + return result + + def create_img_with_bounding_box(self, img, results): + """ + Returns an image with all bounding boxes avove the threshold overlaid + """ + + results = self.get_most_confident_bounding_box_per_label(results) + font = cv2.FONT_HERSHEY_SIMPLEX + + for label, score, box in results: + img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5) + if box[3] + 25 > 768: + y = box[3] - 10 + else: + y = box[3] + 25 + img = cv2.putText( + img, label, (box[0], y), font, 1, (255, 0, 0), 2, cv2.LINE_AA + ) + + return img + + def update_label(self, labels): + self.labels = labels + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--file", + type=str, + default="/home/akshara/spot/spot_rl_experiments/spot_rl/grasp_visualizations/1650841878.2699108.png", + ) + parser.add_argument("--score_threshold", type=float, default=0.1) + parser.add_argument("--show_img", type=bool, default=True) + parser.add_argument( + "--labels", + type=list, + default=[ + [ + "lion plush", + "penguin plush", + "teddy bear", + "bear plush", + "caterpilar plush", + "ball plush", + "rubiks cube", + ] + ], + ) + args = parser.parse_args() + + file = args.file + img = cv2.imread(file) + + V = OwlVit(args.labels, args.score_threshold, args.show_img) + results = V.run_inference(img) + # Keep the window open for 10 seconds + time.sleep(10) diff --git a/spot_rl_experiments/spot_rl/models/sentence_similarity.py b/spot_rl_experiments/spot_rl/models/sentence_similarity.py new file mode 100644 index 00000000..8e661b2c --- /dev/null +++ b/spot_rl_experiments/spot_rl/models/sentence_similarity.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer + + +class SentenceSimilarity: + def __init__(self): + # Load model from HuggingFace Hub + self.tokenizer = AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ) + self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") + + def mean_pooling(self, model_output, attention_mask): + # Mean Pooling - Take attention mask into account for correct averaging + + token_embeddings = model_output[ + 0 + ] # First element of model_output contains all token embeddings + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + def get_similarity_two_sentences(self, a, b): + sentences = [a, b] + + # Tokenize sentences + encoded_input = self.tokenizer( + sentences, padding=True, truncation=True, return_tensors="pt" + ) + + # Compute token embeddings + with torch.no_grad(): + model_output = self.model(**encoded_input) + + # Perform pooling + sentence_embeddings = self.mean_pooling( + model_output, encoded_input["attention_mask"] + ) + + # Normalize embeddings + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + + # compute cosine similarity between embeddings + cosine_scores = sentence_embeddings[0] @ sentence_embeddings[1].T + return cosine_scores + + def get_most_similar_in_list(self, query_word, list): + sentences = [query_word] + [word.replace("_", " ") for word in list] + encoded_input = self.tokenizer( + sentences, padding=True, truncation=True, return_tensors="pt" + ) + with torch.no_grad(): + model_output = self.model(**encoded_input) + + # Perform pooling + sentence_embeddings = self.mean_pooling( + model_output, encoded_input["attention_mask"] + ) + + # Normalize embeddings + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + + # compute cosine similarity between embeddings + cosine_scores = sentence_embeddings[0] @ sentence_embeddings[1:].T + print( + f"word queried : {query_word} | word list : {list} | cosine scores : {cosine_scores}" + ) + + return list[torch.argmax(cosine_scores).item()] diff --git a/spot_rl_experiments/spot_rl/real_policy.py b/spot_rl_experiments/spot_rl/real_policy.py new file mode 100644 index 00000000..50d188e7 --- /dev/null +++ b/spot_rl_experiments/spot_rl/real_policy.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from collections import OrderedDict + +import numpy as np +import torch +from gym import spaces +from gym.spaces import Dict as SpaceDict +from habitat_baselines.rl.ppo.moe import NavGazeMixtureOfExpertsMask +from habitat_baselines.rl.ppo.policy import PointNavBaselinePolicy +from habitat_baselines.utils.common import batch_obs + + +# Turn numpy observations into torch tensors for consumption by policy +def to_tensor(v): + if torch.is_tensor(v): + return v + elif isinstance(v, np.ndarray): + return torch.from_numpy(v) + else: + return torch.tensor(v, dtype=torch.float) + + +class RealPolicy: + def __init__( + self, + checkpoint_path, + observation_space, + action_space, + device, + policy_class=PointNavBaselinePolicy, + ): + print("Loading policy...") + self.device = torch.device(device) + if isinstance(checkpoint_path, str): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + else: + checkpoint = checkpoint_path + config = checkpoint["config"] + + """ Disable observation transforms for real world experiments """ + config.defrost() + config.RL.POLICY.OBS_TRANSFORMS.ENABLED_TRANSFORMS = [] + config.freeze() + config.RL.POLICY["init"] = False + + self.policy = policy_class.from_config( + config=config, + observation_space=observation_space, + action_space=action_space, + ) + print("Actor-critic architecture:", self.policy) + # Move it to the device + self.policy.to(self.device) + + # Load trained weights into the policy + self.policy.load_state_dict( + {k[len("actor_critic.") :]: v for k, v in checkpoint["state_dict"].items()} + ) + + self.prev_actions = None + self.test_recurrent_hidden_states = None + self.not_done_masks = None + self.config = config + self.num_actions = action_space.shape[0] + self.reset_ran = False + print("Policy loaded.") + + def reset(self): + self.reset_ran = True + self.test_recurrent_hidden_states = torch.zeros( + 1, # The number of environments. Just one for real world. + self.policy.net.num_recurrent_layers, + self.config.RL.PPO.hidden_size, + device=self.device, + ) + + # We start an episode with 'done' being True (0 for 'not_done') + self.not_done_masks = torch.zeros(1, 1, dtype=torch.bool, device=self.device) + self.prev_actions = torch.zeros(1, self.num_actions, device=self.device) + + def act(self, observations): + assert self.reset_ran, "You need to call .reset() on the policy first." + batch = batch_obs([observations], device=self.device) + with torch.no_grad(): + _, actions, _, self.test_recurrent_hidden_states = self.policy.act( + batch, + self.test_recurrent_hidden_states, + self.prev_actions, + self.not_done_masks, + deterministic=True, + actions_only=True, + ) + self.prev_actions.copy_(actions) + self.not_done_masks = torch.ones(1, 1, dtype=torch.bool, device=self.device) + + # GPU/CPU torch tensor -> numpy + actions = actions.squeeze().cpu().numpy() + + return actions + + +class GazePolicy(RealPolicy): + def __init__(self, checkpoint_path, device): + observation_space = SpaceDict( + { + "arm_depth": spaces.Box( + low=0.0, high=1.0, shape=(240, 228, 1), dtype=np.float32 + ), + "arm_depth_bbox": spaces.Box( + low=0.0, high=1.0, shape=(240, 228, 1), dtype=np.float32 + ), + "joint": spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32), + "is_holding": spaces.Box( + low=0.0, high=1.0, shape=(1,), dtype=np.float32 + ), + } + ) + action_space = spaces.Box(-1.0, 1.0, (4,)) + super().__init__(checkpoint_path, observation_space, action_space, device) + + +class PlacePolicy(RealPolicy): + def __init__(self, checkpoint_path, device): + observation_space = SpaceDict( + { + "joint": spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32), + "obj_start_sensor": spaces.Box( + low=0.0, high=1.0, shape=(3,), dtype=np.float32 + ), + } + ) + action_space = spaces.Box(-1.0, 1.0, (4,)) + super().__init__(checkpoint_path, observation_space, action_space, device) + + +class NavPolicy(RealPolicy): + def __init__(self, checkpoint_path, device): + observation_space = SpaceDict( + { + "spot_left_depth": spaces.Box( + low=0.0, high=1.0, shape=(212, 120, 1), dtype=np.float32 + ), + "spot_right_depth": spaces.Box( + low=0.0, high=1.0, shape=(212, 120, 1), dtype=np.float32 + ), + "goal_heading": spaces.Box( + low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32 + ), + "target_point_goal_gps_and_compass_sensor": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ), + } + ) + # Linear, angular, and horizontal velocity (in that order) + action_space = spaces.Box(-1.0, 1.0, (2,)) + super().__init__(checkpoint_path, observation_space, action_space, device) + + +class MixerPolicy(RealPolicy): + def __init__( + self, + mixer_checkpoint_path, + nav_checkpoint_path, + gaze_checkpoint_path, + place_checkpoint_path, + device, + ): + observation_space = SpaceDict( + { + "spot_left_depth": spaces.Box( + low=0.0, high=1.0, shape=(212, 120, 1), dtype=np.float32 + ), + "spot_right_depth": spaces.Box( + low=0.0, high=1.0, shape=(212, 120, 1), dtype=np.float32 + ), + "goal_heading": spaces.Box( + low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32 + ), + "target_point_goal_gps_and_compass_sensor": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ), + "arm_depth": spaces.Box( + low=0.0, high=1.0, shape=(240, 228, 1), dtype=np.float32 + ), + "arm_depth_bbox": spaces.Box( + low=0.0, high=1.0, shape=(240, 228, 1), dtype=np.float32 + ), + "joint": spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32), + "is_holding": spaces.Box( + low=0.0, high=1.0, shape=(1,), dtype=np.float32 + ), + "obj_start_sensor": spaces.Box( + low=0.0, high=1.0, shape=(3,), dtype=np.float32 + ), + "visual_features": spaces.Box( + low=0.0, high=1.0, shape=(1024,), dtype=np.float32 + ), + } + ) + checkpoint = torch.load(mixer_checkpoint_path, map_location="cpu") + checkpoint["config"].RL.POLICY["nav_checkpoint_path"] = nav_checkpoint_path + checkpoint["config"].RL.POLICY["gaze_checkpoint_path"] = gaze_checkpoint_path + checkpoint["config"].RL.POLICY["place_checkpoint_path"] = place_checkpoint_path + # checkpoint["config"].RL.POLICY["use_residuals"] = False + checkpoint["config"]["NUM_ENVIRONMENTS"] = 1 + action_space = spaces.Box(-1.0, 1.0, (6 + 3,)) + super().__init__( + checkpoint, + observation_space, + action_space, + device, + policy_class=NavGazeMixtureOfExpertsMask, + ) + self.not_done = torch.zeros(1, 1, dtype=torch.bool, device=self.device) + self.moe_actions = None + self.policy.deterministic_nav = True + self.policy.deterministic_gaze = True + self.policy.deterministic_place = True + self.nav_silence_only = True + self.test_recurrent_hidden_states = torch.zeros( + self.config.NUM_ENVIRONMENTS, + 1, + 512 * 3, + device=self.device, + ) + + def reset(self): + self.not_done = torch.zeros(1, 1, dtype=torch.bool, device=self.device) + self.test_recurrent_hidden_states = torch.zeros( + self.config.NUM_ENVIRONMENTS, + 1, + 512 * 3, + device=self.device, + ) + + def act(self, observations, expert=None): + transformed_obs = self.policy.transform_obs([observations], self.not_done) + batch = batch_obs(transformed_obs, device=self.device) + with torch.no_grad(): + _, actions, _, self.test_recurrent_hidden_states = self.policy.act( + batch, + self.test_recurrent_hidden_states, + None, + self.not_done, + deterministic=False, + # deterministic=True, + actions_only=True, + ) + + # GPU/CPU torch tensor -> numpy + self.not_done = torch.ones(1, 1, dtype=torch.bool, device=self.device) + actions = actions.squeeze().cpu().numpy() + + activated_experts = [] + corrective_actions = OrderedDict() + corrective_actions["arm"] = actions[:4] + corrective_actions["base"] = actions[4:6] + if actions[-3] > 0: + activated_experts.append("nav") + corrective_actions.pop("base") + self.nav_silence_only = True + else: + self.nav_silence_only = False + if actions[-2] > 0: + activated_experts.append("gaze") + corrective_actions.pop("arm") + if actions[-1] > 0: + activated_experts.append("place") + corrective_actions.pop("arm") + corrective_actions_list = [] + for v in corrective_actions.values(): + for vv in v: + corrective_actions_list.append(f"{vv:.3f}") + print( + f"gater: {', '.join(activated_experts)}\t" + f"corrective: {', '.join(corrective_actions_list)}" + ) + + self.moe_actions = actions + action_dict = self.policy.action_to_dict(actions, 0, use_residuals=False) + step_action = action_dict["action"]["action"].numpy() + arm_action, base_action = np.split(step_action, [4]) + + return base_action, arm_action + + +if __name__ == "__main__": + gaze_policy = GazePolicy( + "weights/bbox_mask_5thresh_autograsp_shortrange_seed1_36.pth", + device="cpu", + ) + gaze_policy.reset() + observations = { + "arm_depth": np.zeros([240, 320, 1], dtype=np.float32), + "arm_depth_bbox": np.zeros([240, 320, 1], dtype=np.float32), + "joint": np.zeros(4, dtype=np.float32), + "is_holding": np.zeros(1, dtype=np.float32), + } + actions = gaze_policy.act(observations) + print("actions:", actions) diff --git a/spot_rl_experiments/spot_rl/ros_img_vis.py b/spot_rl_experiments/spot_rl/ros_img_vis.py new file mode 100644 index 00000000..b5cc32a6 --- /dev/null +++ b/spot_rl_experiments/spot_rl/ros_img_vis.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os +import os.path as osp +import time +from collections import deque + +import cv2 +import numpy as np +import rospy +import tqdm +from spot_rl.utils.robot_subscriber import SpotRobotSubscriberMixin +from spot_rl.utils.utils import ros_topics as rt +from spot_wrapper.utils import resize_to_tallest + +RAW_IMG_TOPICS = [rt.HEAD_DEPTH, rt.HAND_DEPTH, rt.HAND_RGB] + +PROCESSED_IMG_TOPICS = [ + rt.FILTERED_HEAD_DEPTH, + rt.FILTERED_HAND_DEPTH, + rt.MASK_RCNN_VIZ_TOPIC, +] + +FOUR_CC = cv2.VideoWriter_fourcc(*"MP4V") +FPS = 30 + + +class VisualizerMixin: + def __init__(self, headless=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.recording = False + self.frames = [] + self.headless = headless + self.curr_video_time = time.time() + self.out_path = None + self.video = None + self.dim = None + self.new_video_started = False + self.named_window = "ROS Spot Images" + + def generate_composite(self): + raise NotImplementedError + + @staticmethod + def overlay_text(img, text, color=(0, 0, 255), size=2.0, thickness=4): + viz_img = img.copy() + line, font, font_size, font_thickness = ( + text, + cv2.FONT_HERSHEY_SIMPLEX, + size, + thickness, + ) + + height, width = img.shape[:2] + y0, dy = 100, 100 + for i, line in enumerate(text.split("\n")): + text_width, text_height = cv2.getTextSize( + line, font, font_size, font_thickness + )[0][:2] + + x = (width - text_width) // 2 + # y = (height - text_height) // 2 + y = y0 + i * dy + cv2.putText( + viz_img, + line, + (x, y), + font, + font_size, + color, + font_thickness, + lineType=cv2.LINE_AA, + ) + return viz_img + + def initializeWindow(self): + cv2.namedWindow(self.named_window, cv2.WINDOW_NORMAL) + + def vis_imgs(self): + # Skip if no messages were updated + currently_saving = not self.recording and self.frames + img = self.generate_composite() if not currently_saving else None + if not self.headless: + if img is not None: + if self.recording: + viz_img = self.overlay_text(img, "RECORDING IS ON!") + cv2.imshow(self.named_window, viz_img) + else: + cv2.imshow(self.named_window, img) + + key = cv2.waitKey(1) + if key != -1: + if ord("r") == key and not currently_saving: + self.recording = not self.recording + elif ord("q") == key: + exit() + + if img is not None: + self.dim = img.shape[:2] + + # Video recording + if self.recording: + self.frames.append(time.time()) + if self.video is None: + height, width = img.shape[:2] + self.out_path = f"{time.time()}.mp4" + self.video = cv2.VideoWriter( + self.out_path, FOUR_CC, FPS, (width, height) + ) + self.video.write(img) + + if currently_saving and not self.recording: + self.save_video() + + def save_video(self): + if self.video is None: + return + # Close window while we work + cv2.destroyAllWindows() + + # Save current buffer + self.video.release() + old_video = cv2.VideoCapture(self.out_path) + ret, img = old_video.read() + + # Re-make video with correct timing + height, width = self.dim + self.new_video_started = True + new_video = cv2.VideoWriter( + self.out_path.replace(".mp4", "_final.mp4"), + FOUR_CC, + FPS, + (width, height), + ) + curr_video_time = self.frames[0] + for idx, timestamp in enumerate(tqdm.tqdm(self.frames)): + if not ret: + break + if idx + 1 >= len(self.frames): + new_video.write(img) + else: + next_timestamp = self.frames[idx + 1] + while curr_video_time < next_timestamp: + new_video.write(img) + curr_video_time += 1 / FPS + ret, img = old_video.read() + + new_video.release() + os.remove(self.out_path) + self.video, self.out_path, self.frames = None, None, [] + self.new_video_started = False + + def delete_videos(self): + for i in [self.out_path, self.out_path.replace(".mp4", "_final.mp4")]: + if osp.isfile(i): + os.remove(i) + + +class SpotRosVisualizer(VisualizerMixin, SpotRobotSubscriberMixin): + node_name = "SpotRosVisualizer" + no_raw = False + proprioception = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_seen = {topic: time.time() for topic in self.msgs.keys()} + self.fps = {topic: deque(maxlen=10) for topic in self.msgs.keys()} + + def generate_composite(self): + if not any(self.updated.values()): + # No imgs were refreshed. Skip. + return None + + refreshed_topics = [k for k, v in self.updated.items() if v] + + # Gather latest images + raw_msgs = [self.msgs[i] for i in RAW_IMG_TOPICS] + processed_msgs = [self.msgs[i] for i in PROCESSED_IMG_TOPICS] + + raw_imgs = [self.msg_to_cv2(i) for i in raw_msgs if i is not None] + + # Replace any Nones with black images if raw version exists. We (safely) assume + # here that there is no processed image anyway if the raw image does not exist. + processed_imgs = [] + for idx, raw_msg in enumerate(raw_msgs): + if processed_msgs[idx] is not None: + processed_imgs.append(self.msg_to_cv2(processed_msgs[idx])) + elif processed_msgs[idx] is None and raw_msg is not None: + processed_imgs.append(np.zeros_like(raw_imgs[idx])) + + # Crop gripper images + if raw_msgs[1] is not None: + for imgs in [raw_imgs, processed_imgs]: + imgs[1] = imgs[1][:, 124:-60] + + img = np.vstack( + [ + resize_to_tallest(bgrify_grayscale_imgs(i), hstack=True) + for i in [raw_imgs, processed_imgs] + ] + ) + + # Add Pick receptacle, Object, Place receptacle information on the side + pck = rospy.get_param("/viz_pick", "None") + obj = rospy.get_param("/viz_object", "None") + plc = rospy.get_param("/viz_place", "None") + information_string = ( + "Pick from:\n" + + pck + + "\n\nObject Target:\n" + + obj + + "\n\nPlace to:\n" + + plc + ) + display_img = 255 * np.ones( + (img.shape[0], int(img.shape[1] / 4), img.shape[2]), dtype=np.uint8 + ) + display_img = self.overlay_text( + display_img, information_string, color=(255, 0, 0), size=1.5, thickness=4 + ) + img = resize_to_tallest([img, display_img], hstack=True) + + for topic in refreshed_topics: + curr_time = time.time() + self.updated[topic] = False + self.fps[topic].append(1 / (curr_time - self.last_seen[topic])) + self.last_seen[topic] = curr_time + + all_topics = RAW_IMG_TOPICS + PROCESSED_IMG_TOPICS + print(" ".join([f"{k[1:]}: {np.mean(self.fps[k]):.2f}" for k in all_topics])) + + return img + + +def bgrify_grayscale_imgs(imgs): + return [ + cv2.cvtColor(i, cv2.COLOR_GRAY2BGR) if i.ndim == 2 or i.shape[-1] == 1 else i + for i in imgs + ] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--headless", action="store_true") + parser.add_argument("--record", action="store_true") + args = parser.parse_args() + + srv = None + try: + srv = SpotRosVisualizer(headless=args.headless) + srv.initializeWindow() + if args.record: + srv.recording = True + while not rospy.is_shutdown(): + srv.vis_imgs() + except Exception as e: + print("Ending script.") + if not args.headless: + cv2.destroyAllWindows() + if srv is not None: + try: + if srv.new_video_started: + print("Deleting unfinished videos.") + srv.delete_videos() + else: + srv.save_video() + except Exception: + print("Deleting unfinished videos") + srv.delete_videos() + exit() + raise e + + +if __name__ == "__main__": + main() diff --git a/spot_rl_experiments/spot_rl/spot_ros_node.py b/spot_rl_experiments/spot_rl/spot_ros_node.py new file mode 100644 index 00000000..ebf436f2 --- /dev/null +++ b/spot_rl_experiments/spot_rl/spot_ros_node.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# import argparse +# import time +# +# import blosc +# import cv2 +# import numpy as np +# import rospy +# from cv_bridge import CvBridge +# from sensor_msgs.msg import CompressedImage, Image +# from spot_wrapper.spot import Spot, SpotCamIds, image_response_to_cv2, scale_depth_img +# from spot_wrapper.utils import say +# from std_msgs.msg import ( +# ByteMultiArray, +# Float32MultiArray, +# MultiArrayDimension, +# MultiArrayLayout, +# String, +# ) +# +# from spot_rl.utils.depth_map_utils import fill_in_multiscale +# +# ROBOT_VEL_TOPIC = "/spot_cmd_velocities" +# MASK_RCNN_VIZ_TOPIC = "/mask_rcnn_visualizations" +# COMPRESSED_IMAGES_TOPIC = "/spot_cams/compressed_images" +# ROBOT_STATE_TOPIC = "/robot_state" +# TEXT_TO_SPEECH_TOPIC = "/text_to_speech" +# SRC2MSG = { +# SpotCamIds.FRONTLEFT_DEPTH: ByteMultiArray, +# SpotCamIds.FRONTRIGHT_DEPTH: ByteMultiArray, +# SpotCamIds.HAND_DEPTH_IN_HAND_COLOR_FRAME: ByteMultiArray, +# SpotCamIds.HAND_COLOR: CompressedImage, +# } +# MAX_DEPTH = 3.5 +# MAX_HAND_DEPTH = 1.7 +# +# NAV_POSE_BUFFER_LEN = 3 +# +# +# class SpotRosPublisher: +# def __init__(self, spot): +# rospy.init_node("spot_ros_node", disable_signals=True) +# self.spot = spot +# +# # For generating Image ROS msgs +# self.cv_bridge = CvBridge() +# +# # Instantiate raw image publishers +# self.sources = list(SRC2MSG.keys()) +# self.img_pub = rospy.Publisher( +# COMPRESSED_IMAGES_TOPIC, ByteMultiArray, queue_size=1, tcp_nodelay=True +# ) +# +# self.last_publish = time.time() +# rospy.loginfo("[spot_ros_node]: Publishing has started.") +# +# def publish_msgs(self): +# st = time.time() +# if st < self.last_publish + 1 / 8: +# return +# +# image_responses = self.spot.get_image_responses(self.sources, quality=100) +# retrieval_time = time.time() - st +# # Publish raw images +# src2details = {} +# for src, response in zip(self.sources, image_responses): +# img = image_response_to_cv2(response) +# +# if "depth" in src: +# # Rescale depth images here +# if src == SpotCamIds.HAND_DEPTH_IN_HAND_COLOR_FRAME: +# max_depth = MAX_HAND_DEPTH +# else: +# max_depth = MAX_DEPTH +# img = scale_depth_img(img, max_depth=max_depth) +# img = np.uint8(img * 255.0) +# img_bytes = blosc.pack_array( +# img, cname="zstd", clevel=3, shuffle=blosc.NOSHUFFLE +# ) +# else: +# # RGB should be JPEG compressed instead of using blosc +# img_bytes = np.array(cv2.imencode(".jpg", img)[1]) +# img_bytes = (img_bytes.astype(int) - 128).astype(np.int8) +# src2details[src] = { +# "dims": MultiArrayDimension(label=src, size=len(img_bytes)), +# "bytes": img_bytes, +# } +# +# depth_bytes = b"" +# rgb_bytes = [] +# depth_dims = [] +# rgb_dims = [] +# for k, v in src2details.items(): +# if "depth" in k: +# depth_bytes += v["bytes"] +# depth_dims.append(v["dims"]) +# else: +# rgb_bytes.append(v["bytes"]) +# rgb_dims.append(v["dims"]) +# depth_bytes = np.frombuffer(depth_bytes, dtype=np.uint8) +# depth_bytes = depth_bytes.astype(int) - 128 +# bytes_data = np.concatenate([depth_bytes, *rgb_bytes]) +# timestamp = int(str(int(st * 1000))[-6:]) +# timestamp_dim = MultiArrayDimension(label="", size=timestamp) +# dims = depth_dims + rgb_dims + [timestamp_dim] +# +# msg = ByteMultiArray(layout=MultiArrayLayout(dim=dims), data=bytes_data) +# self.img_pub.publish(msg) +# +# rospy.loginfo( +# f"[spot_ros_node]: Image retrieval / publish time: " +# f"{1 / retrieval_time:.4f} / {1 / (time.time() - self.last_publish):.4f} Hz" +# ) +# self.last_publish = time.time() +# +# +# class SpotRosSubscriber: +# def __init__(self, node_name, is_blind=False, proprioception=True): +# rospy.init_node(node_name, disable_signals=True) +# +# # For generating Image ROS msgs +# self.cv_bridge = CvBridge() +# if not is_blind: +# +# # Instantiate subscribers +# rospy.Subscriber( +# COMPRESSED_IMAGES_TOPIC, +# ByteMultiArray, +# self.compressed_callback, +# queue_size=1, +# buff_size=2 ** 30, +# ) +# rospy.Subscriber( +# MASK_RCNN_VIZ_TOPIC, +# Image, +# self.viz_callback, +# queue_size=1, +# buff_size=2 ** 30, +# ) +# +# if proprioception: +# rospy.Subscriber( +# ROBOT_STATE_TOPIC, +# Float32MultiArray, +# self.robot_state_callback, +# queue_size=1, +# ) +# +# # Msg holders +# self.compressed_imgs_msg = None +# self.front_depth = None +# self.hand_depth = None +# self.hand_rgb = None +# self.det = None +# self.x = 0.0 +# self.y = 0.0 +# self.yaw = 0.0 +# self.current_arm_pose = None +# self.link_wr1_position, self.link_wr1_rotation = None, None +# self.lock = False +# +# self.updated = False +# rospy.loginfo(f"[{node_name}]: Subscribing has started.") +# self.last_compressed_subscribe = time.time() +# +# def viz_callback(self, msg): +# self.det = msg +# self.updated = True +# +# def compressed_callback(self, msg): +# if self.lock: +# return +# msg.layout.dim, timestamp_dim = msg.layout.dim[:-1], msg.layout.dim[-1] +# latency = (int(str(int(time.time() * 1000))[-6:]) - timestamp_dim.size) / 1000 +# print("Latency: ", latency) +# # if latency > 0.5: +# # return +# self.compressed_imgs_msg = msg +# self.updated = True +# self.last_compressed_subscribe = time.time() +# +# def uncompress_imgs(self): +# assert self.compressed_imgs_msg is not None, "No compressed imgs received!" +# self.lock = True +# byte_data = (np.array(self.compressed_imgs_msg.data) + 128).astype(np.uint8) +# size_and_labels = [ +# (int(dim.size), str(dim.label)) +# for dim in self.compressed_imgs_msg.layout.dim +# ] +# self.lock = False +# self.hand_depth, self.hand_rgb, self.front_depth = uncompress_img_msg( +# byte_data, size_and_labels +# ) +# +# def robot_state_callback(self, msg): +# self.x, self.y, self.yaw = msg.data[:3] +# self.current_arm_pose = msg.data[3:-7] +# self.link_wr1_position, self.link_wr1_rotation = ( +# msg.data[-7:][:3], +# msg.data[-7:][3:], +# ) +# +# @staticmethod +# def filter_depth(depth_img, max_depth, whiten_black=True): +# filtered_depth_img = ( +# fill_in_multiscale(depth_img.astype(np.float32) * (max_depth / 255.0))[0] +# * (255.0 / max_depth) +# ).astype(np.uint8) +# # Recover pixels that weren't black before but were turned black by filtering +# recovery_pixels = np.logical_and(depth_img != 0, filtered_depth_img == 0) +# filtered_depth_img[recovery_pixels] = depth_img[recovery_pixels] +# if whiten_black: +# filtered_depth_img[filtered_depth_img == 0] = 255 +# return filtered_depth_img +# +# +# class SpotRosProprioceptionPublisher: +# def __init__(self, spot): +# rospy.init_node("spot_ros_proprioception_node", disable_signals=True) +# self.spot = spot +# +# # Instantiate filtered image publishers +# self.pub = rospy.Publisher(ROBOT_STATE_TOPIC, Float32MultiArray, queue_size=1) +# self.last_publish = time.time() +# rospy.loginfo("[spot_ros_proprioception_node]: Publishing has started.") +# +# self.nav_pose_buff = None +# self.buff_idx = 0 +# +# def publish_msgs(self): +# st = time.time() +# robot_state = self.spot.get_robot_state() +# msg = Float32MultiArray() +# xy_yaw = self.spot.get_xy_yaw(robot_state=robot_state, use_boot_origin=True) +# if self.nav_pose_buff is None: +# self.nav_pose_buff = np.tile(xy_yaw, [NAV_POSE_BUFFER_LEN, 1]) +# else: +# self.nav_pose_buff[self.buff_idx] = xy_yaw +# self.buff_idx = (self.buff_idx + 1) % NAV_POSE_BUFFER_LEN +# xy_yaw = np.mean(self.nav_pose_buff, axis=0) +# +# joints = self.spot.get_arm_proprioception(robot_state=robot_state).values() +# +# position, rotation = self.spot.get_base_transform_to("link_wr1") +# gripper_transform = [position.x, position.y, position.z] + [ +# rotation.x, +# rotation.y, +# rotation.z, +# rotation.w, +# ] +# +# msg.data = np.array( +# list(xy_yaw) + [j.position.value for j in joints] + gripper_transform, +# dtype=np.float32, +# ) +# +# # Limit publishing to 10 Hz max +# if time.time() - self.last_publish > 1 / 10: +# self.pub.publish(msg) +# rospy.loginfo( +# f"[spot_ros_proprioception_node]: " +# "Proprioception retrieval / publish time: " +# f"{1/(time.time() - st):.4f} / " +# f"{1/(time.time() - self.last_publish):.4f} Hz" +# ) +# self.last_publish = time.time() +# +# +# def uncompress_img_msg( +# byte_data, size_and_labels, head=True, gripper_depth=True, gripper_rgb=True +# ): +# start = 0 +# eyes = {} +# hand_depth, hand_rgb, front_depth = None, None, None +# for size, label in size_and_labels: +# end = start + size +# if "depth" in label: +# try: +# if head and label == SpotCamIds.FRONTLEFT_DEPTH: +# eyes["left"] = blosc.unpack_array(byte_data[start:end].tobytes()) +# elif head and label == SpotCamIds.FRONTRIGHT_DEPTH: +# eyes["right"] = blosc.unpack_array(byte_data[start:end].tobytes()) +# elif ( +# gripper_depth and label == SpotCamIds.HAND_DEPTH_IN_HAND_COLOR_FRAME +# ): +# hand_depth = blosc.unpack_array(byte_data[start:end].tobytes()) +# except: +# pass +# elif gripper_rgb and "color" in label: +# rgb_bytes = byte_data[start:end] +# hand_rgb = cv2.imdecode(rgb_bytes, cv2.IMREAD_COLOR) +# start += size +# +# if len(eyes) == 2: +# front_depth = np.hstack([eyes["right"], eyes["left"]]) +# +# return hand_depth, hand_rgb, front_depth +# +# +# def main(): +# parser = argparse.ArgumentParser() +# parser.add_argument("-p", "--proprioception", action="store_true") +# parser.add_argument("-t", "--text-to-speech", action="store_true") +# args = parser.parse_args() +# +# if args.text_to_speech: +# tts_callback = lambda msg: say(msg.data) +# rospy.init_node("spot_ros_tts_node", disable_signals=True) +# rospy.Subscriber(TEXT_TO_SPEECH_TOPIC, String, tts_callback, queue_size=1) +# rospy.loginfo("[spot_ros_tts_node]: Listening for text to dictate.") +# rospy.spin() +# else: +# if args.proprioception: +# name = "spot_ros_proprioception_node" +# cls = SpotRosProprioceptionPublisher +# else: +# name = "spot_ros_node" +# cls = SpotRosPublisher +# +# spot = Spot(name) +# srn = cls(spot) +# while not rospy.is_shutdown(): +# srn.publish_msgs() +# +# +# if __name__ == "__main__": +# main() diff --git a/spot_rl_experiments/spot_rl/utils/autodock.py b/spot_rl_experiments/spot_rl/utils/autodock.py new file mode 100644 index 00000000..6f563935 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/autodock.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import subprocess +import sys + +subprocess.check_call( + f"{sys.executable} -m spot_rl.envs.nav_env -w dock -d", shell=True +) diff --git a/spot_rl_experiments/spot_rl/utils/depth_map_utils.py b/spot_rl_experiments/spot_rl/utils/depth_map_utils.py new file mode 100644 index 00000000..1e06e004 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/depth_map_utils.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import collections + +import cv2 +import numpy as np + +# Full kernels +FULL_KERNEL_3 = np.ones((3, 3), np.uint8) +FULL_KERNEL_5 = np.ones((5, 5), np.uint8) +FULL_KERNEL_7 = np.ones((7, 7), np.uint8) +FULL_KERNEL_9 = np.ones((9, 9), np.uint8) +FULL_KERNEL_31 = np.ones((31, 31), np.uint8) + +# 3x3 cross kernel +CROSS_KERNEL_3 = np.asarray( + [ + [0, 1, 0], + [1, 1, 1], + [0, 1, 0], + ], + dtype=np.uint8, +) + +# 5x5 cross kernel +CROSS_KERNEL_5 = np.asarray( + [ + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + ], + dtype=np.uint8, +) + +# 5x5 diamond kernel +DIAMOND_KERNEL_5 = np.array( + [ + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [1, 1, 1, 1, 1], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + ], + dtype=np.uint8, +) + +# 7x7 cross kernel +CROSS_KERNEL_7 = np.asarray( + [ + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + ], + dtype=np.uint8, +) + +# 7x7 diamond kernel +DIAMOND_KERNEL_7 = np.asarray( + [ + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + ], + dtype=np.uint8, +) + + +def filter_depth(depth_img, max_depth, whiten_black=True): + filtered_depth_img = ( + fill_in_multiscale(depth_img.astype(np.float32) * (max_depth / 255.0))[0] + * (255.0 / max_depth) + ).astype(np.uint8) + # Recover pixels that weren't black before but were turned black by filtering + recovery_pixels = np.logical_and(depth_img != 0, filtered_depth_img == 0) + filtered_depth_img[recovery_pixels] = depth_img[recovery_pixels] + if whiten_black: + filtered_depth_img[filtered_depth_img == 0] = 255 + return filtered_depth_img + + +def fill_in_fast( + depth_map, + max_depth=100.0, + custom_kernel=DIAMOND_KERNEL_5, + extrapolate=False, + blur_type="bilateral", +): + """Fast, in-place depth completion. + + Args: + depth_map: projected depths + max_depth: max depth value for inversion + custom_kernel: kernel to apply initial dilation + extrapolate: whether to extrapolate by extending depths to top of + the frame, and applying a 31x31 full kernel dilation + blur_type: + 'bilateral' - preserves local structure (recommended) + 'gaussian' - provides lower RMSE + + Returns: + depth_map: dense depth map + """ + + # Invert + valid_pixels = depth_map > 0.1 + depth_map[valid_pixels] = max_depth - depth_map[valid_pixels] + + # Dilate + depth_map = cv2.dilate(depth_map, custom_kernel) + + # Hole closing + depth_map = cv2.morphologyEx(depth_map, cv2.MORPH_CLOSE, FULL_KERNEL_5) + + # Fill empty spaces with dilated values + empty_pixels = depth_map < 0.1 + dilated = cv2.dilate(depth_map, FULL_KERNEL_7) + depth_map[empty_pixels] = dilated[empty_pixels] + + # Extend highest pixel to top of image + if extrapolate: + top_row_pixels = np.argmax(depth_map > 0.1, axis=0) + top_pixel_values = depth_map[top_row_pixels, range(depth_map.shape[1])] + + for pixel_col_idx in range(depth_map.shape[1]): + depth_map[ + 0 : top_row_pixels[pixel_col_idx], pixel_col_idx + ] = top_pixel_values[pixel_col_idx] + + # Large Fill + empty_pixels = depth_map < 0.1 + dilated = cv2.dilate(depth_map, FULL_KERNEL_31) + depth_map[empty_pixels] = dilated[empty_pixels] + + # Median blur + depth_map = cv2.medianBlur(depth_map, 5) + + # Bilateral or Gaussian blur + if blur_type == "bilateral": + # Bilateral blur + depth_map = cv2.bilateralFilter(depth_map, 5, 1.5, 2.0) + elif blur_type == "gaussian": + # Gaussian blur + valid_pixels = depth_map > 0.1 + blurred = cv2.GaussianBlur(depth_map, (5, 5), 0) + depth_map[valid_pixels] = blurred[valid_pixels] + + # Invert + valid_pixels = depth_map > 0.1 + depth_map[valid_pixels] = max_depth - depth_map[valid_pixels] + + return depth_map + + +def fill_in_multiscale( + depth_map, + max_depth=100.0, + dilation_kernel_far=CROSS_KERNEL_3, + dilation_kernel_med=CROSS_KERNEL_5, + dilation_kernel_near=CROSS_KERNEL_7, + extrapolate=False, + blur_type="bilateral", + show_process=False, +): + """Slower, multi-scale dilation version with additional noise removal that + provides better qualitative results. + + Args: + depth_map: projected depths + max_depth: max depth value for inversion + dilation_kernel_far: dilation kernel to use for 30.0 < depths < 80.0 m + dilation_kernel_med: dilation kernel to use for 15.0 < depths < 30.0 m + dilation_kernel_near: dilation kernel to use for 0.1 < depths < 15.0 m + extrapolate:whether to extrapolate by extending depths to top of + the frame, and applying a 31x31 full kernel dilation + blur_type: + 'gaussian' - provides lower RMSE + 'bilateral' - preserves local structure (recommended) + show_process: saves process images into an OrderedDict + + Returns: + depth_map: dense depth map + process_dict: OrderedDict of process images + """ + + # Convert to float32 + depths_in = np.float32(depth_map) + + # Calculate bin masks before inversion + valid_pixels_near = (depths_in > 0.1) & (depths_in <= 15.0) + valid_pixels_med = (depths_in > 15.0) & (depths_in <= 30.0) + valid_pixels_far = depths_in > 30.0 + + # Invert (and offset) + s1_inverted_depths = np.copy(depths_in) + valid_pixels = s1_inverted_depths > 0.1 + s1_inverted_depths[valid_pixels] = max_depth - s1_inverted_depths[valid_pixels] + + # Multi-scale dilation + dilated_far = cv2.dilate( + np.multiply(s1_inverted_depths, valid_pixels_far), dilation_kernel_far + ) + dilated_med = cv2.dilate( + np.multiply(s1_inverted_depths, valid_pixels_med), dilation_kernel_med + ) + dilated_near = cv2.dilate( + np.multiply(s1_inverted_depths, valid_pixels_near), dilation_kernel_near + ) + + # Find valid pixels for each binned dilation + valid_pixels_near = dilated_near > 0.1 + valid_pixels_med = dilated_med > 0.1 + valid_pixels_far = dilated_far > 0.1 + + # Combine dilated versions, starting farthest to nearest + s2_dilated_depths = np.copy(s1_inverted_depths) + s2_dilated_depths[valid_pixels_far] = dilated_far[valid_pixels_far] + s2_dilated_depths[valid_pixels_med] = dilated_med[valid_pixels_med] + s2_dilated_depths[valid_pixels_near] = dilated_near[valid_pixels_near] + + # Small hole closure + s3_closed_depths = cv2.morphologyEx( + s2_dilated_depths, cv2.MORPH_CLOSE, FULL_KERNEL_5 + ) + + # Median blur to remove outliers + s4_blurred_depths = np.copy(s3_closed_depths) + blurred = cv2.medianBlur(s3_closed_depths, 5) + valid_pixels = s3_closed_depths > 0.1 + s4_blurred_depths[valid_pixels] = blurred[valid_pixels] + + # Calculate a top mask + top_mask = np.ones(depths_in.shape, dtype=np.bool) + for pixel_col_idx in range(s4_blurred_depths.shape[1]): + pixel_col = s4_blurred_depths[:, pixel_col_idx] + top_pixel_row = np.argmax(pixel_col > 0.1) + top_mask[0:top_pixel_row, pixel_col_idx] = False + + # Get empty mask + valid_pixels = s4_blurred_depths > 0.1 + empty_pixels = ~valid_pixels & top_mask + + # Hole fill + dilated = cv2.dilate(s4_blurred_depths, FULL_KERNEL_9) + s5_dilated_depths = np.copy(s4_blurred_depths) + s5_dilated_depths[empty_pixels] = dilated[empty_pixels] + + # Extend highest pixel to top of image or create top mask + s6_extended_depths = np.copy(s5_dilated_depths) + top_mask = np.ones(s5_dilated_depths.shape, dtype=np.bool) + + top_row_pixels = np.argmax(s5_dilated_depths > 0.1, axis=0) + top_pixel_values = s5_dilated_depths[ + top_row_pixels, range(s5_dilated_depths.shape[1]) + ] + + for pixel_col_idx in range(s5_dilated_depths.shape[1]): + if extrapolate: + s6_extended_depths[ + 0 : top_row_pixels[pixel_col_idx], pixel_col_idx + ] = top_pixel_values[pixel_col_idx] + else: + # Create top mask + top_mask[0 : top_row_pixels[pixel_col_idx], pixel_col_idx] = False + + # Fill large holes with masked dilations + s7_blurred_depths = np.copy(s6_extended_depths) + for i in range(6): + empty_pixels = (s7_blurred_depths < 0.1) & top_mask + dilated = cv2.dilate(s7_blurred_depths, FULL_KERNEL_5) + s7_blurred_depths[empty_pixels] = dilated[empty_pixels] + + # Median blur + blurred = cv2.medianBlur(s7_blurred_depths, 5) + valid_pixels = (s7_blurred_depths > 0.1) & top_mask + s7_blurred_depths[valid_pixels] = blurred[valid_pixels] + + if blur_type == "gaussian": + # Gaussian blur + blurred = cv2.GaussianBlur(s7_blurred_depths, (5, 5), 0) + valid_pixels = (s7_blurred_depths > 0.1) & top_mask + s7_blurred_depths[valid_pixels] = blurred[valid_pixels] + elif blur_type == "bilateral": + # Bilateral blur + blurred = cv2.bilateralFilter(s7_blurred_depths, 5, 0.5, 2.0) + s7_blurred_depths[valid_pixels] = blurred[valid_pixels] + + # Invert (and offset) + s8_inverted_depths = np.copy(s7_blurred_depths) + valid_pixels = np.where(s8_inverted_depths > 0.1) + s8_inverted_depths[valid_pixels] = max_depth - s8_inverted_depths[valid_pixels] + + depths_out = s8_inverted_depths + + process_dict = None + if show_process: + process_dict = collections.OrderedDict() + + process_dict["s0_depths_in"] = depths_in + + process_dict["s1_inverted_depths"] = s1_inverted_depths + process_dict["s2_dilated_depths"] = s2_dilated_depths + process_dict["s3_closed_depths"] = s3_closed_depths + process_dict["s4_blurred_depths"] = s4_blurred_depths + process_dict["s5_combined_depths"] = s5_dilated_depths + process_dict["s6_extended_depths"] = s6_extended_depths + process_dict["s7_blurred_depths"] = s7_blurred_depths + process_dict["s8_inverted_depths"] = s8_inverted_depths + + process_dict["s9_depths_out"] = depths_out + + return depths_out, process_dict diff --git a/spot_rl_experiments/spot_rl/utils/generate_place_goal.py b/spot_rl_experiments/spot_rl/utils/generate_place_goal.py new file mode 100644 index 00000000..ade4d81d --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/generate_place_goal.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import magnum as mn +import numpy as np +from spot_rl.envs.base_env import SpotBaseEnv +from spot_wrapper.spot import Spot + +EE_GRIPPER_OFFSET = [0.2, 0.0, 0.05] + + +def get_global_place_target(spot: Spot): + position, rotation = spot.get_base_transform_to("link_wr1") + position = [position.x, position.y, position.z] + rotation = [rotation.x, rotation.y, rotation.z, rotation.w] + wrist_T_base = SpotBaseEnv.spot2habitat_transform(position, rotation) + gripper_T_base = wrist_T_base @ mn.Matrix4.translation( + mn.Vector3(EE_GRIPPER_OFFSET) + ) + base_place_target_habitat = np.array(gripper_T_base.translation) + base_place_target = base_place_target_habitat[[0, 2, 1]] + + x, y, yaw = spot.get_xy_yaw() + base_T_global = mn.Matrix4.from_( + mn.Matrix4.rotation_z(mn.Rad(yaw)).rotation(), + mn.Vector3(mn.Vector3(x, y, 0.5)), + ) + global_place_target = base_T_global.transform_point(base_place_target) + + return global_place_target + + +if __name__ == "__main__": + spot = Spot("PlaceGoalGenerator") + global_place_target = get_global_place_target(spot) + print(global_place_target) diff --git a/spot_rl_experiments/spot_rl/utils/helper_nodes.py b/spot_rl_experiments/spot_rl/utils/helper_nodes.py new file mode 100644 index 00000000..0b865cfb --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/helper_nodes.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import time + +import numpy as np +import rospy +from spot_rl.utils.utils import ros_topics as rt +from spot_wrapper.spot import Spot +from spot_wrapper.utils import say +from std_msgs.msg import Float32MultiArray, String + +NAV_POSE_BUFFER_LEN = 1 + + +class SpotRosProprioceptionPublisher: + def __init__(self, spot): + rospy.init_node("spot_ros_proprioception_node", disable_signals=True) + self.spot = spot + + # Instantiate filtered image publishers + self.pub = rospy.Publisher(rt.ROBOT_STATE, Float32MultiArray, queue_size=1) + self.last_publish = time.time() + rospy.loginfo("[spot_ros_proprioception_node]: Publishing has started.") + + self.nav_pose_buff = None + self.buff_idx = 0 + + def publish_msgs(self): + st = time.time() + robot_state = self.spot.get_robot_state() + msg = Float32MultiArray() + xy_yaw = self.spot.get_xy_yaw(robot_state=robot_state, use_boot_origin=True) + if self.nav_pose_buff is None: + self.nav_pose_buff = np.tile(xy_yaw, [NAV_POSE_BUFFER_LEN, 1]) + else: + self.nav_pose_buff[self.buff_idx] = xy_yaw + self.buff_idx = (self.buff_idx + 1) % NAV_POSE_BUFFER_LEN + xy_yaw = np.mean(self.nav_pose_buff, axis=0) + + joints = self.spot.get_arm_proprioception(robot_state=robot_state).values() + + position, rotation = self.spot.get_base_transform_to("link_wr1") + gripper_transform = [position.x, position.y, position.z] + [ + rotation.x, + rotation.y, + rotation.z, + rotation.w, + ] + + msg.data = np.array( + list(xy_yaw) + [j.position.value for j in joints] + gripper_transform, + dtype=np.float32, + ) + + # Limit publishing to 10 Hz max + if time.time() - self.last_publish > 1 / 10: + self.pub.publish(msg) + rospy.loginfo( + f"[spot_ros_proprioception_node]: " + "Proprioception retrieval / publish time: " + f"{1/(time.time() - st):.4f} / " + f"{1/(time.time() - self.last_publish):.4f} Hz" + ) + self.last_publish = time.time() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--proprioception", action="store_true") + parser.add_argument("-t", "--text-to-speech", action="store_true") + args = parser.parse_args() + + if args.text_to_speech: + tts_callback = lambda msg: say(msg.data) # noqa + rospy.init_node("spot_ros_tts_node", disable_signals=True) + rospy.Subscriber(rt.TEXT_TO_SPEECH, String, tts_callback, queue_size=1) + rospy.loginfo("[spot_ros_tts_node]: Listening for text to dictate.") + rospy.spin() + elif args.proprioception: + name = "SpotRosProprioceptionPublisher" + node = SpotRosProprioceptionPublisher(Spot(name)) + while not rospy.is_shutdown(): + node.publish_msgs() + else: + raise RuntimeError("One and only one arg must be provided.") diff --git a/spot_rl_experiments/spot_rl/utils/img_publishers.py b/spot_rl_experiments/spot_rl/utils/img_publishers.py new file mode 100644 index 00000000..9446d62e --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/img_publishers.py @@ -0,0 +1,486 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os.path as osp +import subprocess +import time +from copy import deepcopy +from typing import Any, List + +import blosc +import cv2 +import numpy as np +import rospy +from cv_bridge import CvBridge +from sensor_msgs.msg import Image +from spot_rl.utils.depth_map_utils import filter_depth +from spot_wrapper.spot import Spot +from spot_wrapper.spot import SpotCamIds as Cam +from spot_wrapper.spot import image_response_to_cv2, scale_depth_img +from std_msgs.msg import ( + ByteMultiArray, + Header, + MultiArrayDimension, + MultiArrayLayout, + String, +) + +try: + from spot_rl.utils.mask_rcnn_utils import ( + generate_mrcnn_detections, + get_deblurgan_model, + get_mrcnn_model, + pred2string, + ) +except ModuleNotFoundError: + pass + +# owlvit +from spot_rl.models import OwlVit +from spot_rl.utils.stopwatch import Stopwatch +from spot_rl.utils.utils import construct_config +from spot_rl.utils.utils import ros_topics as rt + +MAX_PUBLISH_FREQ = 20 +MAX_DEPTH = 3.5 +MAX_HAND_DEPTH = 1.7 + + +class SpotImagePublisher: + name = "" + publisher_topics = List[str] + publish_msg_type = Image + + def __init__(self): + rospy.init_node(self.name, disable_signals=True) + self.cv_bridge = CvBridge() + self.last_publish = time.time() + self.pubs = { + k: rospy.Publisher(k, self.publish_msg_type, queue_size=1, tcp_nodelay=True) + for k in self.publisher_topics + } + rospy.loginfo(f"[{self.name}]: Publisher initialized.") + + def publish(self): + # if st < self.last_publish + 1 / MAX_PUBLISH_FREQ: + # time.sleep(0.01) + # return + self._publish() + self.last_publish = time.time() + + def cv2_to_msg(self, *args, **kwargs) -> Image: + return self.cv_bridge.cv2_to_imgmsg(*args, **kwargs) + + def msg_to_cv2(self, *args, **kwargs) -> np.array: + return self.cv_bridge.imgmsg_to_cv2(*args, **kwargs) + + def _publish(self): + raise NotImplementedError + + +class SpotLocalRawImagesPublisher(SpotImagePublisher): + name = "spot_local_raw_images_publisher" + publisher_topics = [rt.HEAD_DEPTH, rt.HAND_DEPTH, rt.HAND_RGB] + sources = [ + Cam.FRONTRIGHT_DEPTH, + Cam.FRONTLEFT_DEPTH, + Cam.HAND_DEPTH_IN_HAND_COLOR_FRAME, + Cam.HAND_COLOR, + ] + + def __init__(self, spot): + super().__init__() + self.spot = spot + + def _publish(self): + image_responses = self.spot.get_image_responses(self.sources, quality=100) + imgs_list = [image_response_to_cv2(r) for r in image_responses] + imgs = {k: v for k, v in zip(self.sources, imgs_list)} + + head_depth = np.hstack([imgs[Cam.FRONTRIGHT_DEPTH], imgs[Cam.FRONTLEFT_DEPTH]]) + + head_depth = self._scale_depth(head_depth, head_depth=True) + hand_depth = self._scale_depth(imgs[Cam.HAND_DEPTH_IN_HAND_COLOR_FRAME]) + hand_rgb = imgs[Cam.HAND_COLOR] + + msgs = self.imgs_to_msgs(head_depth, hand_depth, hand_rgb) + + for topic, msg in zip(self.pubs.keys(), msgs): + self.pubs[topic].publish(msg) + + def imgs_to_msgs(self, head_depth, hand_depth, hand_rgb): + head_depth_msg = self.cv2_to_msg(head_depth, "mono8") + hand_depth_msg = self.cv2_to_msg(hand_depth, "mono8") + hand_rgb_msg = self.cv2_to_msg(hand_rgb, "bgr8") + + timestamp = rospy.Time.now() + head_depth_msg.header = Header(stamp=timestamp) + hand_depth_msg.header = Header(stamp=timestamp) + hand_rgb_msg.header = Header(stamp=timestamp) + + return head_depth_msg, hand_depth_msg, hand_rgb_msg + + @staticmethod + def _scale_depth(img, head_depth=False): + img = scale_depth_img( + img, max_depth=MAX_DEPTH if head_depth else MAX_HAND_DEPTH + ) + return np.uint8(img * 255.0) + + +class SpotLocalCompressedImagesPublisher(SpotLocalRawImagesPublisher): + name = "spot_local_compressed_images_publisher" + publisher_topics = [rt.COMPRESSED_IMAGES] + publish_msg_type = ByteMultiArray + + def imgs_to_msgs(self, head_depth, hand_depth, hand_rgb): + head_depth_bytes = blosc.pack_array( + head_depth, cname="zstd", clevel=3, shuffle=blosc.NOSHUFFLE + ) + hand_depth_bytes = blosc.pack_array( + hand_depth, cname="zstd", clevel=3, shuffle=blosc.NOSHUFFLE + ) + hand_rgb_bytes = np.array(cv2.imencode(".jpg", hand_rgb)[1]) + hand_rgb_bytes = (hand_rgb_bytes.astype(int) - 128).astype(np.int8) + topic2bytes = { + rt.HEAD_DEPTH: head_depth_bytes, + rt.HAND_DEPTH: hand_depth_bytes, + rt.HAND_RGB: hand_rgb_bytes, + } + topic2details = { + topic: { + "dims": MultiArrayDimension(label=topic, size=len(img_bytes)), + "bytes": img_bytes, + } + for topic, img_bytes in topic2bytes.items() + } + + depth_bytes = b"" + rgb_bytes, depth_dims, rgb_dims = [], [], [] + for topic, details in topic2details.items(): + if "depth" in topic: + depth_bytes += details["bytes"] + depth_dims.append(details["dims"]) + else: + rgb_bytes.append(details["bytes"]) + rgb_dims.append(details["dims"]) + depth_bytes = np.frombuffer(depth_bytes, dtype=np.uint8).astype(int) - 128 + bytes_data = np.concatenate([depth_bytes, *rgb_bytes]) + timestamp = str(time.time()) + timestamp_dim = MultiArrayDimension(label=timestamp, size=0) + dims = depth_dims + rgb_dims + [timestamp_dim] + msg = ByteMultiArray(layout=MultiArrayLayout(dim=dims), data=bytes_data) + return [msg] + + +class SpotProcessedImagesPublisher(SpotImagePublisher): + subscriber_topic = "" + subscriber_msg_type = Image + + def __init__(self): + super().__init__() + self.img_msg = None + rospy.Subscriber( + self.subscriber_topic, self.subscriber_msg_type, self.cb, queue_size=1 + ) + rospy.loginfo(f"[{self.name}]: is waiting for images...") + while self.img_msg is None: + pass + rospy.loginfo(f"[{self.name}]: has received images!") + self.updated = True + + def publish(self): + if not self.updated: + return + super().publish() + self.updated = False + + def cb(self, msg: Image): + self.img_msg = msg + self.updated = True + + +class SpotDecompressingRawImagesPublisher(SpotProcessedImagesPublisher): + name = "spot_decompressing_raw_images_publisher" + publisher_topics = [rt.HEAD_DEPTH, rt.HAND_DEPTH, rt.HAND_RGB] + subscriber_topic = rt.COMPRESSED_IMAGES + subscriber_msg_type = ByteMultiArray + + def _publish(self): + if self.img_msg is None: + return + img_msg = deepcopy(self.img_msg) + + py_timestamp = float(img_msg.layout.dim[-1].label) + latency = time.time() - py_timestamp + latency_msg = f"[{self.name}]: Latency is {latency:.2f} sec" + if latency < 0.5: + rospy.loginfo(latency_msg + ".") + else: + rospy.logwarn(latency_msg + "!") + timestamp = rospy.Time.from_sec(py_timestamp) + + byte_data = (np.array(img_msg.data) + 128).astype(np.uint8) + size_and_labels = [ + (int(dim.size), str(dim.label)) for dim in img_msg.layout.dim + ] + start = 0 + imgs = {} + for size, label in size_and_labels: + end = start + size + if "depth" in label: + img = blosc.unpack_array(byte_data[start:end].tobytes()) + imgs[label] = img + elif "rgb" in label: + rgb_bytes = byte_data[start:end] + img = cv2.imdecode(rgb_bytes, cv2.IMREAD_COLOR) + else: # timestamp + continue + imgs[label] = img + start += size + + head_depth_msg = self.cv2_to_msg(imgs[rt.HEAD_DEPTH], "mono8") + hand_depth_msg = self.cv2_to_msg(imgs[rt.HAND_DEPTH], "mono8") + hand_rgb_msg = self.cv2_to_msg(imgs[rt.HAND_RGB], "bgr8") + + head_depth_msg.header = Header(stamp=timestamp) + hand_depth_msg.header = Header(stamp=timestamp) + hand_rgb_msg.header = Header(stamp=timestamp) + + self.pubs[rt.HEAD_DEPTH].publish(head_depth_msg) + self.pubs[rt.HAND_DEPTH].publish(hand_depth_msg) + self.pubs[rt.HAND_RGB].publish(hand_rgb_msg) + + +class SpotFilteredDepthImagesPublisher(SpotProcessedImagesPublisher): + max_depth = 0.0 + filtered_depth_topic = "" + + def _publish(self): + depth = self.msg_to_cv2(self.img_msg) + filtered_depth = filter_depth(depth, max_depth=self.max_depth) + img_msg = self.cv_bridge.cv2_to_imgmsg(filtered_depth, "mono8") + img_msg.header = self.img_msg.header + self.pubs[self.publisher_topics[0]].publish(img_msg) + + +class SpotFilteredHeadDepthImagesPublisher(SpotFilteredDepthImagesPublisher): + name = "spot_filtered_head_depth_images_publisher" + subscriber_topic = rt.HEAD_DEPTH + max_depth = MAX_DEPTH + publisher_topics = [rt.FILTERED_HEAD_DEPTH] + + +class SpotFilteredHandDepthImagesPublisher(SpotFilteredDepthImagesPublisher): + name = "spot_filtered_hand_depth_images_publisher" + subscriber_topic = rt.HAND_DEPTH + max_depth = MAX_HAND_DEPTH + publisher_topics = [rt.FILTERED_HAND_DEPTH] + + +class SpotBoundingBoxPublisher(SpotProcessedImagesPublisher): + + # TODO: We eventually want to change this name as well as the publisher topic + name = "spot_mrcnn_publisher" + subscriber_topic = rt.HAND_RGB + publisher_topics = [rt.MASK_RCNN_VIZ_TOPIC] + + def __init__(self, model): + super().__init__() + self.model = model + self.detection_topic = rt.DETECTIONS_TOPIC + + self.config = config = construct_config() + self.image_scale = config.IMAGE_SCALE + self.deblur_gan = get_deblurgan_model(config) + self.grayscale = self.config.GRAYSCALE_MASK_RCNN + + self.pubs[self.detection_topic] = rospy.Publisher( + self.detection_topic, String, queue_size=1, tcp_nodelay=True + ) + self.viz_topic = rt.MASK_RCNN_VIZ_TOPIC + + def preprocess_image(self, img): + if self.image_scale != 1.0: + img = cv2.resize( + img, + (0, 0), + fx=self.image_scale, + fy=self.image_scale, + interpolation=cv2.INTER_AREA, + ) + + if self.deblur_gan is not None: + img = self.deblur_gan(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + if self.grayscale: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + return img + + def _publish(self): + stopwatch = Stopwatch() + header = self.img_msg.header + timestamp = header.stamp + hand_rgb = self.msg_to_cv2(self.img_msg) + + # Internal model + hand_rgb_preprocessed = self.preprocess_image(hand_rgb) + bbox_data, viz_img = self.model.inference( + hand_rgb_preprocessed, timestamp, stopwatch + ) + + # publish data + self.publish_bbox_data(bbox_data) + self.publish_viz_img(viz_img, header) + + stopwatch.print_stats() + + def publish_bbox_data(self, bbox_data): + self.pubs[self.detection_topic].publish(bbox_data) + + def publish_viz_img(self, viz_img, header): + viz_img_msg = self.cv2_to_msg(viz_img) + viz_img_msg.header = header + self.pubs[self.viz_topic].publish(viz_img_msg) + + +class OWLVITModel: + def __init__(self, score_threshold=0.05, show_img=False): + self.config = config = construct_config() + self.owlvit = OwlVit([["ball"]], score_threshold, show_img) + self.image_scale = config.IMAGE_SCALE + rospy.loginfo("[OWLVIT]: Models loaded.") + + def inference(self, hand_rgb, timestamp, stopwatch): + params = rospy.get_param("/object_target").split(",") + self.owlvit.update_label([params]) + bbox_xy, viz_img = self.owlvit.run_inference_and_return_img(hand_rgb) + + if bbox_xy is not None and bbox_xy != []: + detections = [] + for detection in bbox_xy: + str_det = f'{detection[0]},{detection[1]},{",".join([str(i) for i in detection[2]])}' + detections.append(str_det) + bbox_xy_string = ";".join(detections) + else: + bbox_xy_string = "None" + detections_str = f"{int(timestamp.nsecs)}|{bbox_xy_string}" + + return detections_str, viz_img + + +class MRCNNModel: + def __init__(self): + self.config = config = construct_config() + self.mrcnn = get_mrcnn_model(config) + rospy.loginfo("[MRCNN]: Models loaded.") + + def inference(self, hand_rgb, timestamp, stopwatch): + img = hand_rgb + pred = self.mrcnn.inference(img) + if stopwatch is not None: + stopwatch.record("mrcnn_secs") + detections_str = f"{int(timestamp.nsecs)}|{pred2string(pred)}" + viz_img = self.mrcnn.visualize_inference(img, pred) + return detections_str, viz_img + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--filter-head-depth", action="store_true") + parser.add_argument("--filter-hand-depth", action="store_true") + parser.add_argument("--decompress", action="store_true") + parser.add_argument("--raw", action="store_true") + parser.add_argument("--compress", action="store_true") + parser.add_argument("--owlvit", action="store_true") + parser.add_argument("--mrcnn", action="store_true") + parser.add_argument("--core", action="store_true", help="running on the Core") + parser.add_argument("--listen", action="store_true", help="listening to Core") + parser.add_argument( + "--local", action="store_true", help="fully local robot connection" + ) + parser.add_argument( + "--bounding_box_detector", + choices=["owlvit", "mrcnn"], + default="owlvit", + help="bounding box detector model to use (owlvit or maskrcnn)", + ) + + args = parser.parse_args() + # assert ( + # len([i[1] for i in args._get_kwargs() if i[1]]) == 1 + # ), "One and only one arg must be provided." + + filter_head_depth = args.filter_head_depth + filter_hand_depth = args.filter_hand_depth + decompress = args.decompress + raw = args.raw + compress = args.compress + core = args.core + listen = args.listen + local = args.local + bounding_box_detector = args.bounding_box_detector + mrcnn = args.mrcnn + owlvit = args.owlvit + + node = None # type: Any + model = None # type: Any + if filter_head_depth: + node = SpotFilteredHeadDepthImagesPublisher() + elif filter_hand_depth: + node = SpotFilteredHandDepthImagesPublisher() + elif mrcnn: + model = MRCNNModel() + node = SpotBoundingBoxPublisher(model) + elif owlvit: + # TODO dynamic label + rospy.set_param("object_target", "ball") + model = OWLVITModel() + node = SpotBoundingBoxPublisher(model) + elif decompress: + node = SpotDecompressingRawImagesPublisher() + elif raw or compress: + name = "LocalRawImagesPublisher" if raw else "LocalCompressedImagesPublisher" + spot = Spot(name) + if raw: + node = SpotLocalRawImagesPublisher(spot) + else: + node = SpotLocalCompressedImagesPublisher(spot) + else: + assert core or listen or local, "This should be impossible." + + if core or listen or local: + if core: + flags = ["--compress"] + else: + flags = [ + "--filter-head-depth", + "--filter-hand-depth", + f"--{bounding_box_detector}", + ] + if listen: + flags.append("--decompress") + elif local: + flags.append("--raw") + else: + raise RuntimeError("This should be impossible.") + cmds = [f"python {osp.abspath(__file__)} {flag}" for flag in flags] + processes = [subprocess.Popen(cmd, shell=True) for cmd in cmds] + try: + while all([p.poll() is None for p in processes]): + pass + finally: + for p in processes: + try: + p.kill() + except Exception: + pass + else: + while not rospy.is_shutdown(): + node.publish() diff --git a/spot_rl_experiments/spot_rl/utils/mask_rcnn_utils.py b/spot_rl_experiments/spot_rl/utils/mask_rcnn_utils.py new file mode 100644 index 00000000..12c66860 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/mask_rcnn_utils.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os.path as osp + +import cv2 +from deblur_gan.predictor import DeblurGANv2 +from mask_rcnn_detectron2.inference import MaskRcnnInference + + +def generate_mrcnn_detections( + img, scale, mrcnn, grayscale=True, deblurgan=None, return_img=False, stopwatch=None +): + if scale != 1.0: + img = cv2.resize( + img, + (0, 0), + fx=scale, + fy=scale, + interpolation=cv2.INTER_AREA, + ) + if deblurgan is not None: + img = deblurgan(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if stopwatch is not None: + stopwatch.record("deblur_secs") + if grayscale: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + detections = mrcnn.inference(img) + if stopwatch is not None: + stopwatch.record("mrcnn_secs") + + if return_img: + return detections, img + + return detections + + +def pred2string(pred): + detections = pred["instances"] + if len(detections) == 0: + return "None" + + detection_str = [] + for det_idx in range(len(detections)): + class_id = detections.pred_classes[det_idx] + score = detections.scores[det_idx] + x1, y1, x2, y2 = detections.pred_boxes[det_idx].tensor.squeeze(0) + det_attrs = [str(i.item()) for i in [class_id, score, x1, y1, x2, y2]] + detection_str.append(",".join(det_attrs)) + detection_str = ";".join(detection_str) + return detection_str + + +def get_mrcnn_model(config): + mask_rcnn_weights = ( + config.WEIGHTS.MRCNN_50 if config.USE_FPN_R50 else config.WEIGHTS.MRCNN + ) + mask_rcnn_device = config.DEVICE + config_path = "50" if config.USE_FPN_R50 else "101" + mrcnn = MaskRcnnInference( + mask_rcnn_weights, + score_thresh=0.7, + device=mask_rcnn_device, + config_path=config_path, + ) + return mrcnn + + +def get_deblurgan_model(config): + if config.USE_DEBLURGAN and config.USE_MRCNN: + weights_path = config.WEIGHTS.DEBLURGAN + model_name = osp.basename(weights_path).split(".")[0] + print("Loading DeblurGANv2 with:", weights_path) + model = DeblurGANv2(weights_path=weights_path, model_name=model_name) + return model + return None diff --git a/spot_rl_experiments/spot_rl/utils/remote_spot.py b/spot_rl_experiments/spot_rl/utils/remote_spot.py new file mode 100644 index 00000000..8bc493fe --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/remote_spot.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# TODO: Support the following: +# self.x, self.y, self.yaw = self.spot.xy_yaw_global_to_home( +# position, rotation = self.spot.get_base_transform_to("link_wr1") + + +""" +This class allows you to control Spot as if you had a lease to actuate its motors, +but will actually just relay any motor commands to the robot's onboard Core. The Core +is the one that actually possesses the lease and sends motor commands to Spot via +Ethernet (faster, more reliable). + +The message relaying is executed with ROS topic publishing / subscribing. + +Very hacky. +""" + +import json +import time + +import rospy +from spot_wrapper.spot import Spot +from std_msgs.msg import Bool, String + +ROBOT_CMD_TOPIC = "/remote_robot_cmd" +CMD_ENDED_TOPIC = "/remote_robot_cmd_ended" +KILL_REMOTE_ROBOT = "/kill_remote_robot" +INIT_REMOTE_ROBOT = "/init_remote_robot" + + +def isiterable(var): + try: + iter(var) + except TypeError: + return False + else: + return True + + +class RemoteSpot(Spot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # This determines whether the Core has confirmed the last cmd has ended + self.cmd_ended = False + # This subscriber updates the above attribute + rospy.Subscriber(CMD_ENDED_TOPIC, Bool, self.cmd_ended_callback, queue_size=1) + + # This publisher sends the desired command to the Core + self.pub = rospy.Publisher(ROBOT_CMD_TOPIC, String, queue_size=1) + self.remote_robot_killer = rospy.Publisher( + KILL_REMOTE_ROBOT, Bool, queue_size=1 + ) + + # This publisher starts the remote robot + self.init_robot = rospy.Publisher(INIT_REMOTE_ROBOT, Bool, queue_size=1) + + self.error_on_no_response = True + + def cmd_ended_callback(self, msg): + self.cmd_ended = msg.data + + def send_cmd(self, cmd_name, *args, **kwargs): + cmd_with_args_str = f"{cmd_name}" + if args: + cmd_with_args_str += ";" + ";".join([self.arg2str(i) for i in args]) + if kwargs: + cmd_with_args_str += ";" + str(kwargs) + self.pub.publish(cmd_with_args_str) + + @staticmethod + def arg2str(arg): + if isinstance(arg, str): + return arg + if type(arg) in [float, int, bool]: + return str(arg) + elif isiterable(arg): + return f"np.array([{','.join([str(i) for i in arg])}])" + else: + return str(arg) + + def blocking(self, timeout): + start_time = time.time() + self.cmd_ended = False + while not self.cmd_ended and time.time() < start_time + timeout: + # We need to block until we receive confirmation from the Core that the + # grasp has ended + time.sleep(0.1) + self.cmd_ended = False + + if time.time() > start_time + timeout: + if self.error_on_no_response: + raise TimeoutError( + "Did not hear back from remote robot before timeout." + ) + return False + + return True + + def grasp_hand_depth(self, *args, **kwargs): + assert "timeout" in kwargs + self.send_cmd("grasp_hand_depth", *args, **kwargs) + return self.blocking(timeout=kwargs["timeout"]) + + def set_arm_joint_positions( + self, positions, travel_time=1.0, max_vel=2.5, max_acc=15 + ): + self.send_cmd( + "set_arm_joint_positions", + positions, + travel_time, + max_vel, + max_acc, + ) + + def open_gripper(self): + self.send_cmd("open_gripper") + + def set_base_velocity(self, *args, **kwargs): + self.send_cmd("set_base_velocity", *args, **kwargs) + + def set_base_vel_and_arm_pos(self, *args, **kwargs): + self.send_cmd("set_base_vel_and_arm_pos", *args, **kwargs) + + def dock(self, *args, **kwargs): + self.send_cmd("dock", *args, **kwargs) + return self.blocking(timeout=20) + + def power_on(self, *args, **kwargs): + self.init_robot.publish(True) + time.sleep(5) + self.send_cmd("power_on") + return self.blocking(timeout=20) + + def blocking_stand(self, *args, **kwargs): + self.send_cmd("blocking_stand") + return self.blocking(timeout=10) + + def power_off(self, *args, **kwargs): + print("[remote_spot.py]: Asking robot to power off...") + self.remote_robot_killer.publish(True) diff --git a/spot_rl_experiments/spot_rl/utils/remote_spot_listener.py b/spot_rl_experiments/spot_rl/utils/remote_spot_listener.py new file mode 100644 index 00000000..7c5016f7 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/remote_spot_listener.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +""" +The code here should be by the Core only. This will relay any received commands straight +to the robot from the Core via Ethernet. +""" + +import numpy as np # DON'T REMOVE IMPORT +import rospy +from spot_rl.utils.utils import ros_topics as rt +from spot_wrapper.spot import Spot +from std_msgs.msg import Bool, String + + +class RemoteSpotListener: + def __init__(self, spot): + self.spot = spot + assert spot.spot_lease is not None, "Need motor control of Spot!" + + # This subscriber executes received cmds + rospy.Subscriber(rt.ROBOT_CMD_TOPIC, String, self.execute_cmd, queue_size=1) + + # This publisher signals if a cmd has finished + self.pub = rospy.Publisher(rt.CMD_ENDED_TOPIC, Bool, queue_size=1) + + # This subscriber will kill the listener + rospy.Subscriber( + rt.KILL_REMOTE_ROBOT, Bool, self.kill_remote_robot, queue_size=1 + ) + + self.off = False + + def execute_cmd(self, msg): + if self.off: + return + + values = msg.data.split(";") + method_name, args = values[0], values[1:] + method = eval("self.spot." + method_name) + + cmd_str = f"self.spot.{method_name}({args if args else ''})" + rospy.loginfo(f"[RemoteSpotListener]: Executing: {cmd_str}") + + decoded_args = [eval(i) for i in args] + args_vec = [i for i in decoded_args if not isinstance(i, dict)] + kwargs = [i for i in decoded_args if isinstance(i, dict)] + assert len(kwargs) <= 1 + if not kwargs: + kwargs = {} + else: + kwargs = kwargs[0] + method(*args_vec, **kwargs) + self.pub.publish(True) + + def kill_remote_robot(self, msg): + rospy.loginfo("[RemoteSpotListener]: Powering robot off...") + self.spot.power_off() + self.off = True + rospy.signal_shutdown("Robot was powered off.") + exit() + + +class RemoteSpotMaster: + def __init__(self): + rospy.init_node("RemoteSpotMaster", disable_signals=True) + # This subscriber executes received cmds + rospy.Subscriber( + rt.INIT_REMOTE_ROBOT, Bool, self.init_remote_robot, queue_size=1 + ) + self.remote_robot_killer = rospy.Publisher( + rt.KILL_REMOTE_ROBOT, Bool, queue_size=1 + ) + self.lease = None + self.remote_robot_listener = None + rospy.loginfo("[RemoteSpotMaster]: Listening for requests to start robot...") + + def init_remote_robot(self, msg): + if self.lease is not None: + if not self.remote_robot_listener.off: + self.remote_robot_listener.power_off() + self.lease.__exit__(None, None, None) + self.remote_robot_listener = None + + spot = Spot("RemoteSpotListener") + rospy.loginfo("[RemoteSpotMaster]: Starting robot!") + self.lease = spot.get_lease(hijack=True) + self.remote_robot_listener = RemoteSpotListener(spot) + + +if __name__ == "__main__": + RemoteSpotMaster() + rospy.spin() diff --git a/spot_rl_experiments/spot_rl/utils/robot_subscriber.py b/spot_rl_experiments/spot_rl/utils/robot_subscriber.py new file mode 100644 index 00000000..562db7d2 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/robot_subscriber.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial + +import numpy as np +import rospy +from cv_bridge import CvBridge +from sensor_msgs.msg import Image +from spot_rl.utils.utils import ros_topics as rt +from spot_wrapper.spot import Spot +from std_msgs.msg import Float32MultiArray + +IMG_TOPICS = [ + rt.MASK_RCNN_VIZ_TOPIC, + rt.HEAD_DEPTH, + rt.HAND_DEPTH, + rt.HAND_RGB, + rt.FILTERED_HEAD_DEPTH, + rt.FILTERED_HAND_DEPTH, +] +NO_RAW_IMG_TOPICS = [ + rt.MASK_RCNN_VIZ_TOPIC, + rt.HAND_RGB, + rt.FILTERED_HEAD_DEPTH, + rt.FILTERED_HAND_DEPTH, +] + + +class SpotRobotSubscriberMixin: + node_name = "SpotRobotSubscriber" + no_raw = False + proprioception = True + + def __init__(self, spot=None, *args, **kwargs): + super().__init__(*args, **kwargs) + rospy.init_node(self.node_name, disable_signals=True) + self.cv_bridge = CvBridge() + + subscriptions = NO_RAW_IMG_TOPICS if self.no_raw else IMG_TOPICS + + # Maps a topic name to the latest msg from it + self.msgs = {topic: None for topic in subscriptions} + self.updated = {topic: False for topic in subscriptions} + + for img_topic in subscriptions: + rospy.Subscriber( + img_topic, + Image, + partial(self.img_callback, img_topic), + queue_size=1, + buff_size=2**30, + ) + rospy.loginfo(f"[{self.node_name}]: Waiting for images...") + while not all([self.msgs[s] is not None for s in subscriptions]): + pass + rospy.loginfo(f"[{self.node_name}]: Received images!") + + self.x = 0.0 + self.y = 0.0 + self.yaw = 0.0 + self.current_arm_pose = None + self.link_wr1_position, self.link_wr1_rotation = None, None + if self.proprioception: + rospy.Subscriber( + rt.ROBOT_STATE, + Float32MultiArray, + self.robot_state_callback, + queue_size=1, + ) + assert spot is not None + self.spot = spot + else: + self.spot = None + + self.pick_target = "None" + self.pick_object = "None" + self.place_target = "None" + + rospy.loginfo(f"[{self.node_name}]: Robot subscribing has started.") + + def img_callback(self, topic, msg): + self.msgs[topic] = msg + self.updated[topic] = True + + def robot_state_callback(self, msg): + x, y, yaw = msg.data[:3] + self.x, self.y, self.yaw = self.spot.xy_yaw_global_to_home(x, y, yaw) + self.current_arm_pose = msg.data[3:-7] + self.link_wr1_position, self.link_wr1_rotation = ( + msg.data[-7:][:3], + msg.data[-7:][3:], + ) + + def msg_to_cv2(self, *args, **kwargs) -> np.array: + return self.cv_bridge.imgmsg_to_cv2(*args, **kwargs) diff --git a/spot_rl_experiments/spot_rl/utils/run_local_parallel_inference.py b/spot_rl_experiments/spot_rl/utils/run_local_parallel_inference.py new file mode 100644 index 00000000..baea11cf --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/run_local_parallel_inference.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import os.path as osp +import subprocess + +this_dir = osp.dirname(osp.abspath(__file__)) +local_parallel_inference = osp.join(this_dir, "img_publishers.py") + +cmds = [ + f"python {local_parallel_inference}", + f"python {local_parallel_inference} --nav", +] + +processes = [subprocess.Popen(cmd, shell=True) for cmd in cmds] +try: + while any([p.poll() is None for p in processes]): + pass +finally: + [p.kill() for p in processes] diff --git a/spot_rl_experiments/spot_rl/utils/run_parallel_inference.py b/spot_rl_experiments/spot_rl/utils/run_parallel_inference.py new file mode 100644 index 00000000..f8c6b546 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/run_parallel_inference.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# mypy: ignore-errors +import os.path as osp +import subprocess + +this_dir = osp.dirname(osp.abspath(__file__)) +depth_node_script = osp.join(this_dir, "depth_filter_node.py") +mask_rcnn_node_script = osp.join(this_dir, "mask_rcnn_utils.py") + +cmds = [ + f"python {depth_node_script}", + f"python {depth_node_script} --head", + f"python {mask_rcnn_node_script}", +] + +processes = [subprocess.Popen(cmd, shell=True) for cmd in cmds] +try: + while any([p.poll() is None for p in processes]): + pass +finally: + [p.kill() for p in processes] diff --git a/spot_rl_experiments/spot_rl/utils/spot_rl_launch_local.py b/spot_rl_experiments/spot_rl/utils/spot_rl_launch_local.py new file mode 100644 index 00000000..5ce8caf5 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/spot_rl_launch_local.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/spot_rl_experiments/spot_rl/utils/stopwatch.py b/spot_rl_experiments/spot_rl/utils/stopwatch.py new file mode 100644 index 00000000..4ae9965c --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/stopwatch.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import time +from collections import OrderedDict, deque + +import numpy as np + + +class Stopwatch: + def __init__(self, window_size=50): + self.window_size = window_size + self.times = OrderedDict() + self.current_time = time.time() + + def reset(self): + self.current_time = time.time() + + def record(self, key): + if key not in self.times: + self.times[key] = deque(maxlen=self.window_size) + self.times[key].append(time.time() - self.current_time) + self.current_time = time.time() + + def print_stats(self, latest=False): + name2time = OrderedDict() + for k, v in self.times.items(): + if latest: + name2time[k] = v[-1] + else: + name2time[k] = np.mean(v) + name2time["total"] = np.sum(list(name2time.values())) + print(" ".join([f"{k}: {v:.4f}" for k, v in name2time.items()])) diff --git a/spot_rl_experiments/spot_rl/utils/utils.py b/spot_rl_experiments/spot_rl/utils/utils.py new file mode 100644 index 00000000..55f67238 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/utils.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os.path as osp +from collections import OrderedDict + +import numpy as np +import yaml +from yacs.config import CfgNode as CN + +this_dir = osp.dirname(osp.abspath(__file__)) +spot_rl_dir = osp.join(osp.dirname(this_dir)) +spot_rl_experiments_dir = osp.join(osp.dirname(spot_rl_dir)) +configs_dir = osp.join(spot_rl_experiments_dir, "configs") +DEFAULT_CONFIG = osp.join(configs_dir, "config.yaml") +WAYPOINTS_YAML = osp.join(configs_dir, "waypoints.yaml") + +ROS_TOPICS = osp.join(configs_dir, "ros_topic_names.yaml") +ros_topics = CN() +ros_topics.set_new_allowed(True) +ros_topics.merge_from_file(ROS_TOPICS) + + +def get_waypoint_yaml(waypoint_file=WAYPOINTS_YAML): + with open(waypoint_file) as f: + return yaml.safe_load(f) + + +def get_default_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-o", "--opts", nargs="*", default=[]) + return parser + + +def construct_config(opts=None): + if opts is None: + opts = [] + config = CN() + config.set_new_allowed(True) + config.merge_from_file(DEFAULT_CONFIG) + config.merge_from_list(opts) + + new_weights = {} + for k, v in config.WEIGHTS.items(): + if not osp.isfile(v): + new_v = osp.join(spot_rl_experiments_dir, v) + if not osp.isfile(new_v): + raise KeyError(f"Neither {v} nor {new_v} exist!") + new_weights[k] = new_v + config.WEIGHTS.update(new_weights) + + return config + + +def nav_target_from_waypoints(waypoint): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + goal_x, goal_y, goal_heading = waypoints_yaml["nav_targets"][waypoint] + return goal_x, goal_y, np.deg2rad(goal_heading) + + +def place_target_from_waypoints(waypoint): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + return np.array(waypoints_yaml["place_targets"][waypoint]) + + +def closest_clutter(x, y, clutter_blacklist=None): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + if clutter_blacklist is None: + clutter_blacklist = [] + clutter_locations = [ + (np.array(nav_target_from_waypoints(w)[:2]), w) + for w in waypoints_yaml["clutter"] + if w not in clutter_blacklist + ] + xy = np.array([x, y]) + dist_to_clutter = lambda i: np.linalg.norm(i[0] - xy) # noqa + _, waypoint_name = sorted(clutter_locations, key=dist_to_clutter)[0] + return waypoint_name, nav_target_from_waypoints(waypoint_name) + + +def object_id_to_nav_waypoint(object_id): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + if isinstance(object_id, str): + for k, v in waypoints_yaml["object_targets"].items(): + if v[0] == object_id: + object_id = int(k) + break + if isinstance(object_id, str): + KeyError(f"{object_id} not a valid class name!") + place_nav_target_name = waypoints_yaml["object_targets"][object_id][1] + return place_nav_target_name, nav_target_from_waypoints(place_nav_target_name) + + +def object_id_to_object_name(object_id): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + return waypoints_yaml["object_targets"][object_id][0] + + +def get_clutter_amounts(): + waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML) + return waypoints_yaml["clutter_amounts"] + + +def arr2str(arr): + if arr is not None: + return f"[{', '.join([f'{i:.2f}' for i in arr])}]" + return + + +class FixSizeOrderedDict(OrderedDict): + def __init__(self, *args, maxlen=0, **kwargs): + self._maxlen = maxlen + super().__init__(*args, **kwargs) + + def __setitem__(self, key, value): + OrderedDict.__setitem__(self, key, value) + if self._maxlen > 0: + if len(self) > self._maxlen: + self.popitem(False) diff --git a/spot_rl_experiments/spot_rl/utils/waypoint_recorder.py b/spot_rl_experiments/spot_rl/utils/waypoint_recorder.py new file mode 100644 index 00000000..d4783eb5 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/waypoint_recorder.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os.path as osp +import sys +from typing import Dict + +import numpy as np +import ruamel.yaml +from spot_rl.utils.generate_place_goal import get_global_place_target +from spot_rl.utils.utils import get_default_parser +from spot_wrapper.spot import Spot + +spot_rl_dir = osp.abspath(__file__) +for _ in range(3): + spot_rl_dir = osp.dirname(spot_rl_dir) +WAYPOINT_YAML = osp.join(spot_rl_dir, "configs/waypoints.yaml") + + +def parse_arguments(args): + parser = get_default_parser() + parser.add_argument("-c", "--clutter", help="input:string -> clutter target name") + parser.add_argument( + "-p", "--place-target", help="input:string -> place target name" + ) + parser.add_argument("-n", "--nav-only", help="input:string -> nav target name") + parser.add_argument( + "-x", + "--create-file", + action="store_true", + help="input: not needed -> create a new waypoints.yaml file", + ) + args = parser.parse_args(args=args) + + return args + + +class YamlHandler: + """ + Class to handle reading and writing to yaml files + + How to use: + 1. Create a simple yaml file with the following format: + + place_targets: + test_receptacle: + - 3.0 + - 0.0 + - 0.8 + clutter: + - test_receptacle + clutter_amounts: + test_receptacle: 1 + object_targets: + 0: [penguin, test_receptacle] + nav_targets: + dock: + - 1.5 + - 0.0 + - 0.0 + test_receptacle: + - 2.5 + - 0.0 + - 0.0 + + 2. Create an instance of this class + 3. Read the yaml file using the read_yaml method as a dict + 4. Modify the yaml_dict outside of this class object as needed + 5. Write the yaml_dict yaml_file using the write_yaml method with the created instance + + Example: + yaml_handler = YamlHandler() + yaml_dict = yaml_handler.read_yaml(waypoint_file=waypoint_file) + yaml_dict["nav_targets"]["test_receptacle"] = [2.5, 0.0, 0.0] # Modify the yaml_dict + yaml_handler.write_yaml(waypoint_file=waypoint_file, yaml_dict=yaml_dict) + """ + + def __init__(self): + pass + + def construct_yaml_dict(self): + """ + Constructs and returns a simple yaml dict with "test_receptacle" as nav, place, and clutter target and dock as nav target + """ + init_yaml_dict = """ +place_targets: # i.e., where an object needs to be placed (x,y,z) + test_receptacle: + - 3.0 + - 0.0 + - 0.8 +clutter: # i.e., where an object is currently placed +# + - test_receptacle +clutter_amounts: # i.e., how much clutter exists in each receptacle +# : + test_receptacle: 1 +object_targets: # i.e., where an object belongs / needs to be placed + # : [, ] + 0: [penguin, test_receptacle] +nav_targets: # i.e., where the robot needs to navigate to (x,y,yaw) + dock: + - 1.5 + - 0.0 + - 0.0 + test_receptacle: + - 2.5 + - 0.0 + - 0.0""" + + return init_yaml_dict + + @staticmethod + def read_yaml(self, waypoint_file: str): + """ + Read a yaml file and returns a dict + + Args: + waypoint_file (str): path to yaml file + + Returns: + yaml_dict (dict): Contens of the yaml file as a dict if it exists, else an contructs a new simple yaml_dict + """ + yaml_dict = {} # type: Dict + yaml = ruamel.yaml.YAML() # defaults to round-trip if no parameters given + + # Read yaml file if it exists + if osp.exists(waypoint_file): + with open(waypoint_file, "r") as f: + print( + f"Reading waypoints from already existing waypoints.yaml file at {waypoint_file}" + ) + yaml_dict = yaml.load(f.read()) + else: + print( + f"Creating a new waypoints dict as waypoints.yaml does not exist on path {waypoint_file}" + ) + yaml_dict = yaml.load(self.construct_yaml_dict()) + + return yaml_dict + + @staticmethod + def write_yaml(self, waypoint_file: str, yaml_dict): + """ + Write the yaml_dict into the yaml_file. + If the yaml_file does not exist, it will be created. + + Args: + waypoint_file (str): path to yaml file + yaml_dict (dict): dict to be written to yaml file + """ + with open(waypoint_file, "w+") as f: + ruamel.yaml.dump(yaml_dict, f, Dumper=ruamel.yaml.RoundTripDumper) + + +class WaypointRecorder: + """ + Class to record waypoints and clutter targets for the Spot robot + + How to use: + 1. Create an instance of this class + 2. Call the record_nav_target method with the nav_target_name as an argument (str) + 3. Call the record_clutter_target method with the clutter_target_name as an argument (str) + 4. Call the record_place_target method with the place_target_name as an argument (str) + 5. Call the save_yaml method to save the waypoints to the yaml file + + + Args: + spot (Spot): Spot robot object + waypoint_file_path (str): path to yaml file to save waypoints to + + + Example: + waypoint_recorder = WaypointRecorder(spot=Spot) + waypoint_recorder.record_nav_target("test_nav_target") + waypoint_recorder.record_clutter_target("test_clutter_target") + waypoint_recorder.record_place_target("test_place_target") + waypoint_recorder.save_yaml() + """ + + def __init__(self, spot: Spot, waypoint_file_path: str = WAYPOINT_YAML): + self.spot = spot + + # Local copy of waypoints.yaml which keeps getting updated as new waypoints are added + self.waypoint_file = waypoint_file_path + self.yaml_handler = YamlHandler() + self.yaml_dict = {} # type: Dict + + def init_yaml(self): + """ + Initialize member variable `self.yaml_dict` with the contents of the yaml file as a dict if it is not initialized. + """ + if self.yaml_dict == {}: + self.yaml_dict = self.yaml_handler.read_yaml( + waypoint_file=self.waypoint_file + ) + + def save_yaml(self): + """ + Save the waypoints (self.yaml_dict) to the yaml file if it is not empty. + It will overwrite the existing yaml file if it exists and will create a new one if it does not exist. + """ + if self.yaml_dict == {}: + print("No waypoints to save. Exiting...") + return + + self.yaml_handler.write_yaml(self.waypoint_file, self.yaml_dict) + print( + f"Successfully saved(/overwrote) all waypoints to file at {self.waypoint_file}:\n" + ) + + def unmark_clutter(self, clutter_target_name: str): + """ + INTERNAL METHOD: + Unmark a waypoint as clutter if it is already marked. + + It is used internally by the `record_nav_target` method to unmark a waypoint as clutter if it is already marked. + This is done to avoid cluttering the yaml file with duplicate clutter targets and also to update the waypoints.yaml + file if a previously marked "clutter" is not marked as a "nav_target" anymore. + + Args: + clutter_target_name (str): name of the waypoint to be unmarked as clutter + """ + # Add clutter list if not present + if "clutter" not in self.yaml_dict: + self.yaml_dict["clutter"] = [] + # Remove waypoint from clutter list if it exists + elif clutter_target_name in self.yaml_dict.get("clutter"): + print(f"Unmarking {clutter_target_name} from clutter list") + self.yaml_dict.get("clutter").remove(clutter_target_name) + + def mark_clutter(self, clutter_target_name: str): + """ + INTERNAL METHOD: + Mark a waypoint as clutter if it is not already marked. + + It is used internally by the `record_clutter_target` method to mark a waypoint as clutter if it is not already marked. + + Args: + clutter_target_name (str): name of the waypoint to be marked as clutter + """ + # Add clutter list if not present + if "clutter" not in self.yaml_dict: + self.yaml_dict["clutter"] = [] + + # Add waypoint as clutter if it does not exist + if clutter_target_name not in self.yaml_dict.get("clutter"): + print(f"Marking {clutter_target_name} in clutter list") + self.yaml_dict.get("clutter").append(clutter_target_name) + + def record_nav_target(self, nav_target_name: str): + """ + Record a waypoint as a nav target + + It will also unmark the waypoint as clutter if it is already marked as clutter. + If "nav_targets" does not exist, it will create a new "nav_targets" list initialized with the default "dock" waypoint + and add the new waypoint to it. + If the waypoint already exists, it will overwrite the existing waypoint data. + + Args: + nav_target_name (str): name of the waypoint to be recorded as a nav target + """ + # Initialize yaml_dict + self.init_yaml() + + # Get current nav pose + x, y, yaw = self.spot.get_xy_yaw() + yaw_deg = np.rad2deg(yaw) + nav_target = [float(x), float(y), float(yaw_deg)] + + # Unmark waypoint as clutter if it is already marked + self.unmark_clutter(clutter_target_name=nav_target_name) + + # Add nav_targets list if not present + if "nav_targets" not in self.yaml_dict: + self.yaml_dict["nav_targets"] = { + "dock": "[1.5, 0.0, 0.0]", + } + + # Erase existing waypoint data if present + if nav_target_name in self.yaml_dict.get("nav_targets"): + print( + f"Nav target for {nav_target_name} already exists as follows inside waypoints.yaml and will be overwritten." + ) + print( + f"old waypoint : {self.yaml_dict.get('nav_targets').get(nav_target_name)}" + ) + input("Press Enter if you want to continue...") + + # Add new waypoint data + self.yaml_dict.get("nav_targets").update({nav_target_name: nav_target}) + + def record_clutter_target(self, clutter_target_name: str): + """ + Record a waypoint as a clutter target + + It will initialize the member variable `self.yaml_dict` with appropriate content. + It will mark the waypoint as nav target, thereby also clearing it from the clutter list if it is already marked as clutter + It will mark the waypoint as clutter if not done already. + It will add the waypoint to the "clutter_amounts" list if it does not exist, and will update its value to 1. + + Args: + clutter_target_name (str): name of the waypoint to be recorded as a clutter target + """ + # Initialize yaml_dict + self.init_yaml() + + self.record_nav_target(clutter_target_name) + + # Mark waypoint as clutter + self.mark_clutter(clutter_target_name=clutter_target_name) + + # Add clutter_amounts list if not present + if "clutter_amounts" not in self.yaml_dict: + self.yaml_dict["clutter_amounts"] = {} + + # Add waypoint as clutter_amounts if it does not exist + if clutter_target_name not in self.yaml_dict.get("clutter_amounts"): + self.yaml_dict["clutter_amounts"].update({clutter_target_name: 1}) + print( + f"Added {clutter_target_name} in 'clutter_amounts' => ({clutter_target_name}:{self.yaml_dict.get('clutter_amounts').get(clutter_target_name)})" + ) + else: + print( + f"{clutter_target_name} already exists in 'clutter_amounts' => ({clutter_target_name}:{self.yaml_dict.get('clutter_amounts').get(clutter_target_name)})" + ) + + def record_place_target(self, place_target_name: str): + """ + Record a waypoint as a place target + + It will initialize the member variable `self.yaml_dict` with appropriate content + It will mark the waypoint as nav target, thereby also clearing it from the clutter list if it is already marked as clutter + It will add the waypoint to "place_targets" list if it does not exist, and will update its value to the current gripper position. + + Args: + place_target_name (str): name of the waypoint to be recorded as a place target + """ + # Initialize yaml_dict + self.init_yaml() + + self.record_nav_target(place_target_name) + + # Get place target as current gripper position + place_target = get_global_place_target(self.spot) + + # Add place_targets list if not present + if "place_targets" not in self.yaml_dict: + self.yaml_dict["place_targets"] = {} + + # Erase existing waypoint data if present + if place_target_name in self.yaml_dict.get("place_targets"): + print( + f"Place target for {place_target_name} already exists as follows inside waypoints.yaml and will be overwritten." + ) + print( + f"old waypoint : {self.yaml_dict.get('place_targets').get(place_target_name)}" + ) + + # Add new place target data + self.yaml_dict.get("place_targets").update({place_target_name: [*place_target]}) + + +def main(spot: Spot): + args = parse_arguments(args=sys.argv[1:]) + arg_bools = [args.clutter, args.place_target, args.nav_only, args.create_file] + assert ( + len([i for i in arg_bools if i]) == 1 + ), "Must pass in either -c, -p, -n, or -x as an arg, and not more than one." + + # Create WaypointRecorder object with default waypoint file + waypoint_recorder = WaypointRecorder(spot=spot) + + if args.create_file: + waypoint_recorder.init_yaml() + elif args.nav_only: + waypoint_recorder.record_nav_target(args.nav_only) + elif args.clutter: + waypoint_recorder.record_clutter_target(args.clutter) + elif args.place_target: + waypoint_recorder.record_place_target(args.place_target) + else: + raise NotImplementedError + + waypoint_recorder.save_yaml() + + +if __name__ == "__main__": + spot = Spot("WaypointRecorder") + main(spot) diff --git a/spot_rl_experiments/spot_rl/utils/whisper_translator.py b/spot_rl_experiments/spot_rl/utils/whisper_translator.py new file mode 100644 index 00000000..c6540977 --- /dev/null +++ b/spot_rl_experiments/spot_rl/utils/whisper_translator.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import queue +from collections import deque + +import numpy as np +import openai +import sounddevice as sd +import soundfile as sf +import webrtcvad +import whisper + +openai.api_key = os.environ["OPENAI_API_KEY"] + + +class WhisperTranslator: + def __init__(self): + print("\n=====================================") + print("Initializing Whisper Translator") + self.filename = "data/temp_recordings/output.wav" + if not os.path.exists(os.path.dirname(self.filename)): + os.makedirs(os.path.dirname(self.filename)) + + self.sample_rate = 48000 + self.channels = 1 + self.device = self.identify_device("USB Microphone") + + # We record 30 ms of audio at a time + self.block_duration = 30 + self.blocksize = int(self.sample_rate * self.block_duration / 1000) + + # We process 50 chunks of 30 ms each to determine if someone is talking + self.speech_chunk_size = 50 + + # We record at least 4.5 seconds of audio at the beginning + self.minimum_recorded_time = 150 # (150 * 30 ms = 4.5 seconds) + + # If the rolling mean of the speech queue is below this threshold, we stop recording + self.silence_threshold = 0.15 + + # Queue to store speech chunks + self.speech_queue = deque(maxlen=int(self.speech_chunk_size)) + self.q = queue.Queue() + + # Voice Activity Detection + self.vad = webrtcvad.Vad() + self.vad.set_mode(3) + print("=====================================\n") + + def record(self): + """ + Records audio from the microphone, translates it to text and returns the text + """ + + def callback(indata, frames, time, status): + self.q.put(indata.copy()) + + print("Starting Recording") + iters = 0 + with sf.SoundFile( + self.filename, mode="w", samplerate=self.sample_rate, channels=self.channels + ) as f: + with sd.InputStream( + samplerate=self.sample_rate, + channels=self.channels, + device=self.device, + blocksize=self.blocksize, + callback=callback, + ): + while True: + data = self.q.get() + data = np.array(data * 32767, dtype=np.int16) + + # Check if someone is talking in the chunk (Voice Activity Detection) + is_speech = self.vad.is_speech(data.tobytes(), self.sample_rate) + self.speech_queue.append(is_speech) + rolling_mean = sum(self.speech_queue) / self.speech_chunk_size + + if iters > self.minimum_recorded_time: + if rolling_mean < self.silence_threshold: + print("Recording Ended - no voice activity in 1.5 seconds") + break + + f.write(data) + iters += 1 + print("Done Recording") + + def translate(self): + """ + Translates the audio to text using Whisper first from OPENAI CLOUD client and if it fails, then from locally downloaded model + """ + transcript = "default" + try: + with open(self.filename, "rb") as f: + result = openai.Audio.transcribe("whisper-1", f) + transcript = result["text"] + except Exception as e_cloud: + print( + "Error occured while inferencing Whisper from OpenAI CLOUD client: \n", + e_cloud, + ) + + try: + whisper_model = whisper.load_model("base", device="cuda") + audio = whisper.load_audio(self.filename) + audio = whisper.pad_or_trim(audio) + + # make log-Mel spectrogram and move to the same device as the model + mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device) + + # detect the spoken language + _, probs = whisper_model.detect_language(mel) + print(f"Detected language: {max(probs, key=probs.get)}") + + # decode the audio + options = whisper.DecodingOptions() + result = whisper.decode(whisper_model, mel, options) + + # get the transcript out of whisper's decoded result + transcript = result.text + except Exception as e_local: + print( + "Error occured while inferencing Whisper from OpenAI LOCAL client: \n", + e_local, + ) + return transcript + + def identify_device(self, device_name="USB Microphone"): + """ + Identify the device number of the USB Microphone and returns it + """ + device_list = sd.query_devices() + devices = [ + (i, x["name"]) + for i, x in enumerate(device_list) + if device_name in x["name"] + ] + if len(devices) == 0: + print("USB Microphone not found. Using default device") + device_id = 0 + else: + print("Found following devices with name USB Microphone:\n", devices) + if len(devices) > 1: + print("Using first device from the list") + device_id = devices[0][0] + return device_id + + +if __name__ == "__main__": + wt = WhisperTranslator() + wt.record() + translation = wt.translate() + print(translation) diff --git a/third_party/DeblurGANv2 b/third_party/DeblurGANv2 new file mode 160000 index 00000000..b55888a5 --- /dev/null +++ b/third_party/DeblurGANv2 @@ -0,0 +1 @@ +Subproject commit b55888a56bb21d58f0be349f12de604bf3d2752b diff --git a/third_party/habitat-lab b/third_party/habitat-lab new file mode 160000 index 00000000..e6d788b3 --- /dev/null +++ b/third_party/habitat-lab @@ -0,0 +1 @@ +Subproject commit e6d788b35ec7df879deb94ece327921d51f2b9a8 diff --git a/third_party/mask_rcnn_detectron2 b/third_party/mask_rcnn_detectron2 new file mode 160000 index 00000000..19141f5b --- /dev/null +++ b/third_party/mask_rcnn_detectron2 @@ -0,0 +1 @@ +Subproject commit 19141f5b1b77513d3fab5406c72b28ecc307e7be