Skip to content

Commit

Permalink
merge code for new assets, floors etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Dec 14, 2023
1 parent 8f4f6b1 commit 771c27f
Show file tree
Hide file tree
Showing 30 changed files with 535 additions and 86 deletions.
6 changes: 4 additions & 2 deletions mani_skill2/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ def _load_articulation(self):
self.robot_link_ids = [link.name for link in self.robot.get_links()]

def _after_loading_articulation(self):
"""After loading articulation and before setting up controller. Not recommended, but is useful for when creating
robot classes that inherit controllers from another and only change which joints are controlled"""
"""After loading articulation and before setting up controller. Not recommended, but is useful for when creating
robot classes that inherit controllers from another and only change which joints are controlled
"""
pass

def _after_init(self):
"""After initialization. E.g., caching the end-effector link."""
pass
Expand Down
3 changes: 2 additions & 1 deletion mani_skill2/envs/misc/avoid_obstacles.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mani_skill2.agents.robots.panda.panda import Panda
from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.ground import build_tesselated_square_floor
from mani_skill2.utils.io_utils import load_json
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import (
Expand Down Expand Up @@ -118,7 +119,7 @@ def _build_coord_frame_site(self, scale=0.1, name="coord_frame"):
return actor

def _load_actors(self):
self._add_ground(render=self.bg_name is None)
build_tesselated_square_floor(self._scene)

# Add a wall
if "wall" in self.episode_config:
Expand Down
3 changes: 2 additions & 1 deletion mani_skill2/envs/ms1/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mani_skill2.agents.robots.mobile_panda import DummyMobileAgent
from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.ground import build_tesselated_square_floor
from mani_skill2.utils.common import random_choice
from mani_skill2.utils.io_utils import load_json
from mani_skill2.utils.sapien_utils import (
Expand Down Expand Up @@ -110,7 +111,7 @@ def _set_model(self, model_id):

def _load_actors(self):
# Create a collision ground plane
ground = self._add_ground(render=True)
ground = build_tesselated_square_floor(self._scene)
# TODO (stao): This is quite hacky. Future we expect the robot to be an actual well defined robot without needing to intersect the ground. We should probably deprecate the old ms1 envs eventually
# Specify a collision (ignore) group to avoid collision with robot torso
cs = ground.find_component_by_type(
Expand Down
4 changes: 4 additions & 0 deletions mani_skill2/envs/sapien_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def _setup_sensors(self):
)

def _setup_lighting(self):
# TODO (stao): remove this code out. refactor it to be inside scene builders
"""Setup lighting in the scene. Called by `self.reconfigure`"""

shadow = self.enable_shadow
Expand Down Expand Up @@ -728,6 +729,9 @@ def render_human(self):
if self._viewer is None:
self._viewer = Viewer(self._renderer)
self._setup_viewer()
self._viewer.set_camera_pose(
self._render_cameras["render_camera"].camera.global_pose
)
self._viewer.render()
return self._viewer

Expand Down
49 changes: 49 additions & 0 deletions mani_skill2/envs/scenes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Temporary Docs on environments with pre-built scenes

The code at the moment supports scenes created by [AI2-THOR](https://ai2thor.allenai.org/) stored in this Hugging Face Dataset: https://huggingface.co/datasets/hssd/ai2thor-hab/tree/main

Other scenes from other groups/projects are supportable provided you write an `SceneAdapter` class to load that type of configuration. See `mani_skill2/envs/scenes/adapters/hssd` for one such example.

To download ai2thor-hab scene dataset, in the root level run

```bash
# note that this is a very big dataset (~20GB) with 50,000+ files so it can take some time to download.
python mani_skill2/envs/scenes/adapters/hssd/download.py <your_hugging_face_access_token>
```

which downloads all the AI2-THOR scenes to `data/scene_datasets/ai2thor`

A dummy environment that uses just the ArchitecTHOR set of scenes (high quality human built scenes from AI2) can be created via standard gym as so

```python
import mani_skill2.envs
import gymnasium as gym
# PickObjectScene-v0 selects a scene randomly from the selected scene datasets and
# instantiates a robot randomly and selects a random object for the robot to find and pick up.
# render_mode="human" opens up a viewer, convex_decomposition="none" makes scene loading fast (but not well simulated)
# set convex_decomposition="coacd" to use CoACD to get better collision meshes
import sapien.render
env = gym.make("PickObjectScene-v0", scene_datasets=["ArchitecTHOR"], render_mode="human", convex_decomposition="none", fixed_scene=True)

# optionally set these to make it more realistic
sapien.render.set_camera_shader_dir("rt")
sapien.render.set_viewer_shader_dir("rt")
sapien.render.set_ray_tracing_samples_per_pixel(4)
sapien.render.set_ray_tracing_path_depth(2)
sapien.render.set_ray_tracing_denoiser("optix")

env.reset(seed=2, options=dict(reconfigure=True))

while True:
env.render()
```

Note that we set `fixed_scene=True` which is the default option. This means all calls to env.reset(seed=seed) or just env.reset() will always use the same scene. To change scene simply call `env.reset(seed=seed, options=dict(reconfigure=True))` which will always create the same scene depending on the seed here. When `fixed_scene=False` then every call to env.reset will create a new scene and the seed will dictate which scene is created.


## TODOs

- Find a better location to save metadata for scene datasets
- Pick reasonable initial states for robots in scenes that don't collide with the scene.
- More intuitive API for scene setting?
- Add in mobile robots
6 changes: 5 additions & 1 deletion mani_skill2/envs/scenes/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ def reset(self, seed=None, options=None):
return super().reset(seed, options)

def _load_actors(self):
self.scene_builder.build(self._scene, scene_id=self.sampled_scene_idx, convex_decomposition=self.convex_decomposition)
self.scene_builder.build(
self._scene,
scene_id=self.sampled_scene_idx,
convex_decomposition=self.convex_decomposition,
)

def _initialize_agent(self):
if self.robot_uid == "panda":
Expand Down
1 change: 1 addition & 0 deletions mani_skill2/examples/demo_manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def render_wait():
if not args.enable_sapien_viewer:
return
while True:
env.render_human()
sapien_viewer = env.viewer
if sapien_viewer.window.key_down("0"):
break
Expand Down
7 changes: 4 additions & 3 deletions mani_skill2/examples/demo_scenes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
`python -m mani_skill2.examples.demo_scenes` and explore around.
"""
import gymnasium as gym
import sapien.render
import numpy as np
import sapien.render

import mani_skill2.envs
from mani_skill2.utils.scene_builder.ai2thor import (
ArchitecTHORSceneBuilder,
Expand All @@ -41,8 +42,8 @@
)

# optionally set these to make it more realistic
sapien.render.set_camera_shader_dir("rt")
sapien.render.set_viewer_shader_dir("rt")
sapien.render.set_camera_shader_dir("rt2")
sapien.render.set_viewer_shader_dir("rt2")
sapien.render.set_ray_tracing_samples_per_pixel(4)
sapien.render.set_ray_tracing_path_depth(2)
sapien.render.set_ray_tracing_denoiser("optix")
Expand Down
34 changes: 26 additions & 8 deletions mani_skill2/utils/building/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def build_fourcolor_peg(
dynamic: bool = True,
add_collision: bool = True,
):
"""
A peg with four sections and four different colors. Useful for visualizing every possible rotation without any symmetries
"""
builder = scene.create_actor_builder()
if add_collision:
builder.add_box_collision(
Expand Down Expand Up @@ -200,6 +203,9 @@ def build_actor(model_id: str, scene: sapien.Scene, name: str):
pass


### YCB Dataset ###


def build_actor_ycb(
model_id: str,
scene: sapien.Scene,
Expand Down Expand Up @@ -241,35 +247,47 @@ def _load_ycb_dataset():
}


### AI2THOR Object Dataset ###


def build_actor_ai2(
model_id: str,
scene: sapien.Scene,
name: str,
kinematic: bool = False,
set_object_on_ground=True,
):
"""
Builds an actor/object from the AI2THOR assets.
TODO (stao): Automatically makes the origin of the object be the center of the object.
set_object_on_ground: bool
if True, will set the pose of the created actor automatically so that the lowest point of the actor is at z = 0
"""
model_path = (
Path(ASSET_DIR)
/ "scene_datasets/ai2thor/ai2thorhab-uncompressed/assets/objects"
/ f"{model_id}.glb"
)
actor_id = name
builder = scene.create_actor_builder()
builder.add_visual_from_file(str(model_path))
q = transforms3d.quaternions.axangle2quat(np.array([1, 0, 0]), theta=np.deg2rad(90))
pose = sapien.Pose(q=q)
builder.add_visual_from_file(str(model_path), pose=pose)
if kinematic:
builder.add_nonconvex_collision_from_file(str(model_path))
builder.add_nonconvex_collision_from_file(str(model_path), pose=pose)
actor = builder.build_kinematic(name=actor_id)
else:
builder.add_multiple_convex_collisions_from_file(
str(model_path), decomposition="coacd"
str(model_path), decomposition="coacd", pose=pose
)
actor = builder.build(name=actor_id)

q = transforms3d.quaternions.axangle2quat(np.array([1, 0, 0]), theta=np.deg2rad(90))
pose = sapien.Pose(q=q)
actor.set_pose(pose)
aabb = actor.find_component_by_type(
sapien.render.RenderBodyComponent
).compute_global_aabb_tight()
height = aabb[1, 1] - aabb[0, 1]
actor.set_pose(sapien.Pose(p=[0, 0, 0], q=actor.pose.q))
height = aabb[1, 2] - aabb[0, 2]
if set_object_on_ground:
actor.set_pose(sapien.Pose(p=[0, 0, 0]))
return actor
44 changes: 29 additions & 15 deletions mani_skill2/utils/building/articulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,44 @@ class ArticulationMetadata:
movable_links: List[str]


model_dbs: Dict[str, Dict[str, Dict]] = {}
def build_articulation_from_file(
scene: sapien.Scene,
urdf_path: str,
fix_root_link=True,
scale: float = 1.0,
decomposition="none",
set_object_on_ground=True,
):
loader = scene.create_urdf_loader()
loader.multiple_collisions_decomposition = decomposition
loader.fix_root_link = fix_root_link
loader.scale = scale
loader.load_multiple_collisions_from_file = True
articulation: physx.PhysxArticulation = loader.load(urdf_path)
articulation.set_qpos(articulation.qpos)
bounds = merge_meshes(get_articulation_meshes(articulation)).bounds
if set_object_on_ground:
articulation.set_pose(Pose([0, 0, -bounds[0, 2]]))
return articulation, bounds

# TODO optimization: we can cache some results in building articulations and reuse them

# cache model metadata here if needed
model_dbs: Dict[str, Dict[str, Dict]] = {}


### Build articulations ###
def build_partnet_mobility_articulation(
def build_preprocessed_partnet_mobility_articulation(
scene: sapien.Scene,
model_id: str,
fix_root_link=True,
urdf_config: dict = None,
set_object_on_ground=True,
):
"""
Builds a physx.PhysxArticulation object into the scene and returns metadata containing annotations of the object's links and joints
Builds a physx.PhysxArticulation object into the scene and returns metadata containing annotations of the object's links and joints.
This uses preprocessed data from the ManiSkill team where assets were annotated with correct scales and provided
proper convex decompositions of the articulations.
Args:
scene: the sapien scene to add articulation to
Expand All @@ -85,9 +108,6 @@ def build_partnet_mobility_articulation(
articulation: physx.PhysxArticulation = loader.load(str(urdf_path))

metadata = ArticulationMetadata(joints=dict(), links=dict(), movable_links=[])
target_links = []
target_joints = []
target_handles = []

# NOTE(jigu): links and their parent joints.
for link, joint in zip(articulation.get_links(), articulation.get_joints()):
Expand Down Expand Up @@ -122,22 +142,15 @@ def build_partnet_mobility_articulation(
bounds = merge_meshes(get_articulation_meshes(articulation)).bounds
articulation.set_pose(Pose([0, 0, -bounds[0, 2]]))

# for link in self.articulation_metadata.movable_links:
# link_name = self.articulation_metadata.links[link].name
# b = self.articulation_metadata.joints[f"joint_{link.split('_')[1]}"].type
# c = self.articulation_metadata.joints[f"joint_{link.split('_')[1]}"].name
# print(link, link_name, b, c)

return articulation, metadata


def _load_partnet_mobility_dataset():
# load PartnetMobility
"""loads preprocssed partnet mobility metadata"""
model_dbs["PartnetMobility"] = {
"model_data": load_json(
PACKAGE_ASSET_DIR / "partnet_mobility/meta/info_cabinet_drawer_train.json"
),
"builder": build_partnet_mobility_articulation,
}

def find_urdf_path(model_id):
Expand All @@ -147,6 +160,7 @@ def find_urdf_path(model_id):
urdf_path = model_dir / urdf_name
if urdf_path.exists():
return urdf_path

model_dbs["PartnetMobility"]["model_urdf_paths"] = {
k: find_urdf_path(k) for k in model_dbs["PartnetMobility"]["model_data"].keys()
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions mani_skill2/utils/building/assets/tiled_floor.obj
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ v -1.000000 0.000000 -1.000000
v 1.000000 0.000000 -1.000000
vn -0.0000 1.0000 -0.0000
vt 0.000000 0.000000
vt 9.223548 0.000000
vt 9.223548 9.250000
vt 0.000000 9.250000
vt 20 0.000000
vt 20 20
vt 0.000000 20
s 0
f 1/1/1 2/2/1 4/3/1 3/4/1
7 changes: 3 additions & 4 deletions mani_skill2/utils/building/ground.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,21 @@
import sapien.render


def build_tesselated_square_floor(scene: sapien.Scene, altitude=0):
def build_tesselated_square_floor(scene: sapien.Scene, floor_width=20, altitude=0):
ground = scene.create_actor_builder()
render_half_size = 10
rend_mtl = sapien.render.RenderMaterial(
base_color=[0.06, 0.08, 0.12, 1],
metallic=0.0,
roughness=0.9,
specular=0.8,
)
rend_mtl.diffuse_texture = sapien.render.RenderTexture2D(
osp.join(osp.dirname(__file__), "assets/floor_texture.png")
osp.join(osp.dirname(__file__), "assets/floor_texture_4.png")
)
ground.add_visual_from_file(
osp.join(osp.dirname(__file__), "assets/tiled_floor.obj"),
pose=sapien.Pose(p=[0, 0, altitude], q=[0.7071068, 0.7071068, 0, 0]),
scale=[render_half_size, 1, render_half_size],
scale=[floor_width, 1, floor_width],
material=rend_mtl,
)
ground.add_plane_collision(
Expand Down
8 changes: 0 additions & 8 deletions mani_skill2/utils/building/randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,3 @@ def random_quaternion(rng: np.random.RandomState):
# Uniform sample a quaternion
q = transforms3d.quaternions.axangle2quat(rng.rand(3), 2 * np.pi * rng.rand())
return q


def random_position(rng: np.random.RandomState):
pass


def generate_random_reachable_position(rng: np.random.RandomState, source, max_dist):
pass
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Run python mani_skill2/envs/scenes/adapters/hssd/download.py <HF_TOKEN>
Run python -m mani_skill2.utils.scene_builder.ai2thor.download <HF_TOKEN>
"""

import os.path as osp
Expand Down
1 change: 1 addition & 0 deletions mani_skill2/utils/scene_builder/kitchen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .kitchen_scene_builder import KitchenSceneBuilder
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 771c27f

Please sign in to comment.