Skip to content

Commit

Permalink
fix bug with to_tensor calls
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Jan 30, 2024
1 parent 29da919 commit 756dd0f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 16 deletions.
11 changes: 3 additions & 8 deletions mani_skill2/utils/sapien_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@

def to_tensor(array: Union[torch.Tensor, np.array, Sequence]):
"""
Maps any given sequence to the appropriate tensor for the appropriate backend
Note that all torch tensors are always moved to the GPU. There is generally no reason for them to ever be on the CPU as
GPU simulation is the only time torch is used in ManiSkill and thus all tensors from the simulation are on cuda devices
Maps any given sequence to a torch tensor on the CPU/GPU. If physx gpu is not enabled then we use CPU, otherwise GPU.
"""
if get_backend_name() == "torch":
if isinstance(array, np.ndarray):
Expand All @@ -37,11 +34,9 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence]):
return torch.Tensor(array).cuda()
elif get_backend_name() == "numpy":
if isinstance(array, np.ndarray):
return array
elif isinstance(array, torch.Tensor):
return array.cpu().numpy()
return torch.from_numpy(array)
else:
return np.array(array)
return torch.tensor(array)


def to_numpy(array: Union[Array, Sequence]):
Expand Down
6 changes: 3 additions & 3 deletions mani_skill2/utils/structs/articulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def qf(self):

@qf.setter
def qf(self, arg1):
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qf[self._data_index, : self.dof] = arg1
else:
self._objs[0].qf = arg1
Expand All @@ -294,8 +294,8 @@ def qpos(self):

@qpos.setter
def qpos(self, arg1):
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qpos[self._data_index, : self.dof] = arg1
else:
self._objs[0].qpos = arg1
Expand All @@ -309,8 +309,8 @@ def qvel(self):

@qvel.setter
def qvel(self, arg1):
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qvel[self._data_index, : self.dof] = arg1
else:
self._objs[0].qvel = arg1
Expand Down
4 changes: 2 additions & 2 deletions mani_skill2/utils/structs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def angular_velocity(self) -> torch.Tensor:

@angular_velocity.setter
def angular_velocity(self, arg1: Array):
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self._body_data[self._body_data_index, 10:13] = arg1
else:
self._bodies[0].angular_velocity = arg1
Expand Down Expand Up @@ -292,8 +292,8 @@ def linear_velocity(self) -> torch.Tensor:

@linear_velocity.setter
def linear_velocity(self, arg1: Array):
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self._body_data[self._body_data_index, 7:10] = arg1
else:
self._bodies[0].linear_velocity = arg1
Expand Down
5 changes: 2 additions & 3 deletions mani_skill2/utils/structs/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def drive_target(self) -> torch.Tensor:

@drive_target.setter
def drive_target(self, arg1: Array) -> None:
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
raise NotImplementedError(
"Setting drive targets of individual joints is not implemented yet."
)
Expand All @@ -210,8 +210,8 @@ def drive_velocity_target(self) -> torch.Tensor:

@drive_velocity_target.setter
def drive_velocity_target(self, arg1: Array) -> None:
arg1 = to_tensor(arg1)
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
raise NotImplementedError(
"Cannot set drive velocity targets at the moment in GPU simulation"
)
Expand Down Expand Up @@ -248,7 +248,6 @@ def limits(self) -> torch.Tensor:

@limits.setter
def limits(self, arg1: Array) -> None:
arg1 = to_tensor(arg1)
for joint in self._objs:
joint.limits = arg1

Expand Down

0 comments on commit 756dd0f

Please sign in to comment.