diff --git a/bitbots_misc/bitbots_utils/bitbots_utils/utils.py b/bitbots_misc/bitbots_utils/bitbots_utils/utils.py index 74f8d356e..eed87614a 100644 --- a/bitbots_misc/bitbots_utils/bitbots_utils/utils.py +++ b/bitbots_misc/bitbots_utils/bitbots_utils/utils.py @@ -158,7 +158,7 @@ def set_parameters_of_other_node( return [res.success for res in response.results] -def parse_parameter_dict(*, namespace, parameter_dict): +def parse_parameter_dict(*, namespace, parameter_dict) -> list[ParameterMsg]: parameters = [] for param_name, param_value in parameter_dict.items(): full_param_name = namespace + param_name diff --git a/bitbots_motion/bitbots_quintic_walk/bitbots_quintic_walk_py/py_walk.py b/bitbots_motion/bitbots_quintic_walk/bitbots_quintic_walk_py/py_walk.py index 48d51ea67..f3b4bc30c 100644 --- a/bitbots_motion/bitbots_quintic_walk/bitbots_quintic_walk_py/py_walk.py +++ b/bitbots_motion/bitbots_quintic_walk/bitbots_quintic_walk_py/py_walk.py @@ -1,3 +1,5 @@ +from typing import Optional + from biped_interfaces.msg import Phase from bitbots_quintic_walk_py.libpy_quintic_walk import PyWalkWrapper from bitbots_utils.utils import parse_parameter_dict @@ -11,17 +13,24 @@ class PyWalk: - def __init__(self, namespace="", parameters: [Parameter] | None = None, set_force_smooth_step_transition=False): - serialized_parameters = [] - if parameters is not None: - for parameter in parameters: - serialized_parameters.append(serialize_message(parameter)) - if parameter.value.type == 2: - print( - f"Gave parameter {parameter.name} of integer type. If the code crashes it is maybe because this " - f"should be a float. You may need to add an .0 in some yaml file." - ) - self.py_walk_wrapper = PyWalkWrapper(namespace, serialized_parameters, set_force_smooth_step_transition) + def __init__( + self, + namespace="", + walk_parameters: Optional[list[Parameter]] = None, + moveit_parameters: Optional[list[Parameter]] = None, + set_force_smooth_step_transition=False, + ): + def serialize_parameters(parameters): + if parameters is None: + return [] + return list(map(serialize_message, parameters)) + + self.py_walk_wrapper = PyWalkWrapper( + namespace, + serialize_parameters(walk_parameters), + serialize_parameters(moveit_parameters), + set_force_smooth_step_transition, + ) def spin_ros(self): self.py_walk_wrapper.spin_some() @@ -98,24 +107,24 @@ def set_parameters(self, param_dict): for parameter in parameters: self.py_walk_wrapper.set_parameter(serialize_message(parameter)) - def get_phase(self): + def get_phase(self) -> float: return self.py_walk_wrapper.get_phase() - def get_freq(self): + def get_freq(self) -> float: return self.py_walk_wrapper.get_freq() - def get_support_state(self): + def get_support_state(self) -> Phase: return deserialize_message(self.py_walk_wrapper.get_support_state(), Phase) - def is_left_support(self): + def is_left_support(self) -> bool: return self.py_walk_wrapper.is_left_support() - def get_odom(self): + def get_odom(self) -> Odometry: odom = self.py_walk_wrapper.get_odom() result = deserialize_message(odom, Odometry) return result - def publish_debug(self): + def publish_debug(self) -> None: self.py_walk_wrapper.publish_debug() def reset_and_test_if_speed_possible(self, cmd_vel_msg, threshold=0.001): diff --git a/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_node.hpp b/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_node.hpp index 057e65639..ad2da928c 100644 --- a/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_node.hpp +++ b/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_node.hpp @@ -54,7 +54,7 @@ namespace bitbots_quintic_walk { class WalkNode { public: explicit WalkNode(rclcpp::Node::SharedPtr node, const std::string &ns = "", - std::vector parameters = {}); + const std::vector &moveit_parameters = {}); bitbots_msgs::msg::JointCommand step(double dt); bitbots_msgs::msg::JointCommand step(double dt, geometry_msgs::msg::Twist::SharedPtr cmdvel_msg, sensor_msgs::msg::Imu::SharedPtr imu_msg, @@ -112,8 +112,6 @@ class WalkNode { nav_msgs::msg::Odometry getOdometry(); - rcl_interfaces::msg::SetParametersResult onSetParameters(const std::vector ¶meters); - void publish_debug(); rclcpp::TimerBase::SharedPtr startTimer(); double getTimerFreq(); diff --git a/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_pywrapper.hpp b/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_pywrapper.hpp index 48c2961e3..258689f77 100644 --- a/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_pywrapper.hpp +++ b/bitbots_motion/bitbots_quintic_walk/include/bitbots_quintic_walk/walk_pywrapper.hpp @@ -23,7 +23,8 @@ using namespace ros2_python_extension; class PyWalkWrapper { public: - explicit PyWalkWrapper(std::string ns, std::vector parameter_msgs = {}, + explicit PyWalkWrapper(const std::string &ns, const std::vector &walk_parameter_msgs = {}, + const std::vector &moveit_parameter_msgs = {}, bool force_smooth_step_transition = false); py::bytes step(double dt, py::bytes &cmdvel_msg, py::bytes &imu_msg, py::bytes &jointstate_msg, py::bytes &pressure_left, py::bytes &pressure_right); diff --git a/bitbots_motion/bitbots_quintic_walk/src/walk_node.cpp b/bitbots_motion/bitbots_quintic_walk/src/walk_node.cpp index 5e220ee40..8f238f2ff 100644 --- a/bitbots_motion/bitbots_quintic_walk/src/walk_node.cpp +++ b/bitbots_motion/bitbots_quintic_walk/src/walk_node.cpp @@ -9,7 +9,8 @@ using namespace std::chrono_literals; namespace bitbots_quintic_walk { -WalkNode::WalkNode(rclcpp::Node::SharedPtr node, const std::string& ns, std::vector parameters) +WalkNode::WalkNode(rclcpp::Node::SharedPtr node, const std::string& ns, + const std::vector& moveit_parameters) : node_(node), param_listener_(node_), config_(param_listener_.get_params()), @@ -17,21 +18,15 @@ WalkNode::WalkNode(rclcpp::Node::SharedPtr node, const std::string& ns, std::vec stabilizer_(node_), ik_(node_, config_.node.ik), visualizer_(node_, config_.node.tf) { - // Create dummy node for moveit - auto moveit_node = std::make_shared(ns + "walking_moveit_node"); - - // when called from python, parameters are given to the constructor - for (auto parameter : parameters) { - if (node_->has_parameter(parameter.get_name())) { - // this is the case for walk engine params set via python - node_->set_parameter(parameter); - } else { - // parameter is not for the walking, set on moveit node - moveit_node->declare_parameter(parameter.get_name(), parameter.get_type()); - moveit_node->set_parameter(parameter); - } - } - + // Create dummy node for moveit. This is necessary for dynamic reconfigure to work (moveit does some bullshit with + // parameter declarations, so we need to isolate the walking parameters from the moveit parameters). + // If the walking is instantiated using the python wrapper, moveit parameters are passed because no moveit config + // is loaded in the conventional way. Normally the moveit config is loaded via launch file and the passed vector is + // empty. + auto moveit_node = std::make_shared( + "walking_moveit_node", ns, + rclcpp::NodeOptions().automatically_declare_parameters_from_overrides(true).parameter_overrides( + moveit_parameters)); // get all kinematics parameters from the move_group node if they are not set manually via constructor std::string check_kinematic_parameters; if (!moveit_node->get_parameter("robot_description_kinematics.LeftLeg.kinematics_solver", diff --git a/bitbots_motion/bitbots_quintic_walk/src/walk_pywrapper.cpp b/bitbots_motion/bitbots_quintic_walk/src/walk_pywrapper.cpp index f4587962f..3d83ae465 100644 --- a/bitbots_motion/bitbots_quintic_walk/src/walk_pywrapper.cpp +++ b/bitbots_motion/bitbots_quintic_walk/src/walk_pywrapper.cpp @@ -2,21 +2,37 @@ void PyWalkWrapper::spin_some() { rclcpp::spin_some(node_); } -PyWalkWrapper::PyWalkWrapper(std::string ns, std::vector parameter_msgs, bool force_smooth_step_transition) { +PyWalkWrapper::PyWalkWrapper(const std::string &ns, const std::vector &walk_parameter_msgs, + const std::vector &moveit_parameter_msgs, bool force_smooth_step_transition) { // initialize rclcpp if not already done if (!rclcpp::contexts::get_global_default_context()->is_valid()) { rclcpp::init(0, nullptr); } - // create parameters from serialized messages - std::vector cpp_parameters = {}; - for (auto ¶meter_msg : parameter_msgs) { - cpp_parameters.push_back( - rclcpp::Parameter::from_parameter_msg(fromPython(parameter_msg))); - } - - node_ = rclcpp::Node::make_shared(ns + "walking"); - walk_node_ = std::make_shared(node_, ns, cpp_parameters); + // internal function to deserialize the parameter messages + auto deserialize_parameters = [](std::vector parameter_msgs) { + std::vector cpp_parameters = {}; + for (auto ¶meter_msg : parameter_msgs) { + cpp_parameters.push_back( + rclcpp::Parameter::from_parameter_msg(fromPython(parameter_msg))); + } + return cpp_parameters; + }; + + // Create a node object + // Even tho we use python bindings instead of ros's dds, we still need a node object for logging and parameter + // handling Because the walking is not started using the launch infrastructure and an appropriate parameter file, we + // need to manually set the parameters + node_ = rclcpp::Node::make_shared( + "walking", ns, rclcpp::NodeOptions().parameter_overrides(deserialize_parameters(walk_parameter_msgs))); + + // Create the walking object + // We pass it the node we created. But the walking also creates a helper node for moveit (otherwise dynamic + // reconfigure does not work, because moveit does some bullshit with their parameter declarations leading dynamic + // reconfigure not working). This way the walking parameters are isolated from the moveit parameters, allowing dynamic + // reconfigure to work. Therefore we need to pass the moveit parameters to the walking. + walk_node_ = + std::make_shared(node_, ns, deserialize_parameters(moveit_parameter_msgs)); set_robot_state(0); walk_node_->initializeEngine(); walk_node_->getEngine()->setForceSmoothStepTransition(force_smooth_step_transition); @@ -197,7 +213,7 @@ PYBIND11_MODULE(libpy_quintic_walk, m) { using namespace bitbots_quintic_walk; py::class_>(m, "PyWalkWrapper") - .def(py::init, bool>()) + .def(py::init, std::vector, bool>()) .def("step", &PyWalkWrapper::step) .def("step_relative", &PyWalkWrapper::step_relative) .def("step_open_loop", &PyWalkWrapper::step_open_loop) diff --git a/bitbots_wolfgang/wolfgang_description/CMakeLists.txt b/bitbots_wolfgang/wolfgang_description/CMakeLists.txt index 7cfdeef37..968f275a2 100644 --- a/bitbots_wolfgang/wolfgang_description/CMakeLists.txt +++ b/bitbots_wolfgang/wolfgang_description/CMakeLists.txt @@ -59,5 +59,6 @@ ament_export_include_directories(${INCLUDE_DIRS}) install(DIRECTORY launch DESTINATION share/${PROJECT_NAME}) install(DIRECTORY urdf DESTINATION share/${PROJECT_NAME}) install(DIRECTORY config DESTINATION share/${PROJECT_NAME}) +install(DIRECTORY scripts DESTINATION share/${PROJECT_NAME}) ament_package() diff --git a/bitbots_wolfgang/wolfgang_description/scripts/urdf_to_mujoco.py b/bitbots_wolfgang/wolfgang_description/scripts/urdf_to_mujoco.py new file mode 100644 index 000000000..20fb26785 --- /dev/null +++ b/bitbots_wolfgang/wolfgang_description/scripts/urdf_to_mujoco.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +import os +import sys +import tempfile +import mujoco +from lxml import etree + +# This script converts a URDF file to a MuJoCo XML model file. + +INPUT_URDF = '/home/florian/Projekt/bitbots/bitbots_main/bitbots_wolfgang/wolfgang_description/urdf/robot.urdf' + +ACTUATOR_TYPES = ['mx106', 'mx64', 'xh-540'] + +ACTUATOR_MAPPING = { + 'HeadPan': 'mx64', + 'HeadTilt': 'mx64', + 'LShoulderPitch': 'mx106', + 'LShoulderRoll': 'mx106', + 'LElbow': 'mx64', + 'LAnklePitch': 'mx106', + 'LAnkleRoll': 'mx106', + 'LHipYaw': 'mx106', + 'LHipRoll': 'mx106', + 'LHipPitch': 'mx106', + 'LKnee': 'xh-540', + 'RShoulderPitch': 'mx106', + 'RShoulderRoll': 'mx106', + 'RElbow': 'mx64', + 'RAnklePitch': 'mx106', + 'RAnkleRoll': 'mx106', + 'RHipYaw': 'mx106', + 'RHipRoll': 'mx106', + 'RHipPitch': 'mx106', + 'RKnee': 'xh-540', +} # TODO verify the mapping + +ACTUATOR_DEFAULTS = { + 'mx106': { + "joint": { + "damping": "1.23", + "armature": "0.045", + "frictionloss": "2.55", + "limited": "true" + }, + "position": { # This is for the position controller + "kp": "21.1", + "ctrlrange": "-3.141592 3.141592", + "forcerange": "-5 5" + } + }, + 'mx64': { + "joint": { + "damping": "0.65", + "armature": "0.045", + "frictionloss": "1.73", + "limited": "true" + }, + "position": { # This is for the position controller + "kp": "21.1", + "ctrlrange": "-3.141592 3.141592", + "forcerange": "-5 5" + } + }, + 'xh-540': { + "joint": { + "damping": "2.92", + "armature": "0.045", + "frictionloss": "1.49", + "limited": "true" + }, + "position": { # This is for the position controller + "kp": "21.1", + "ctrlrange": "-3.141592 3.141592", + "forcerange": "-5 5" + } + } +} + + +def main(): + # Load the urdf file + if not os.path.exists(INPUT_URDF): + print('Error: URDF file not found:', INPUT_URDF) + sys.exit(1) + urdf_tree = etree.parse(INPUT_URDF) + + # Add the mujoco tag to the URDF file to make it compatible with mujoco + mujoco_tag = etree.Element('mujoco') + mujoco_tag.append(etree.Element('compiler', discardvisual='false', meshdir=os.path.dirname(INPUT_URDF))) + urdf_tree.getroot().append(mujoco_tag) + + # Render the URDF file as a string + urdf_string = etree.tostring(urdf_tree, pretty_print=True) + + # Load the URDF file into mujoco + model = mujoco.MjModel.from_xml_string(urdf_string) + + # Save model as XML (temporary file) + temp_xml_file_path = tempfile.mktemp() + mujoco.mj_saveLastXML(temp_xml_file_path, model) + + # Load the XML file into an etree + tree = etree.parse(temp_xml_file_path) + + # Apply some modifications / fixes + + # Move everything in the worldbody into a new body called torso (also add freejoint and light) # TODO investigate why the torso is not the root body + worldbody = tree.find('.//worldbody') + torso = etree.Element('body', name='torso', pos="0 0 0.4274", quat="0.999 0.0 0.05 0.0") # TODO check if pos is correct + for child in worldbody.getchildren(): + torso.append(child) + torso.append(etree.Element('freejoint')) + worldbody.clear() + worldbody.append(torso) + worldbody.append(etree.Element('light', name='spotlight', mode='targetbodycom', target='torso', pos='0 -1 2')) + + + # Assign classes to all geometries + # Find visual elements, meaning geometries with contype="0" conaffinity="0" group="1" density="0" + for geom in tree.findall('.//geom[@contype="0"][@conaffinity="0"][@group="1"][@density="0"]'): + # Remove the attributes + geom.attrib.pop('contype', None) + geom.attrib.pop('conaffinity', None) + geom.attrib.pop('group', None) + geom.attrib.pop('density', None) + + # Also remove the rgba attribute + geom.attrib.pop('rgba', None) + + # Assign the class attribute + geom.attrib['class'] = 'visual' + + # Find geometries that don't have a class yet and assign them to the collision class + for geom in tree.xpath('.//geom[not(@class)]'): + geom.attrib['class'] = 'collision' + + # Add defaults + + defaults = etree.fromstring(""" + + + + + + + + + + """) + + # Add actuator and joint defaults + assert set(ACTUATOR_DEFAULTS.keys()) == set(ACTUATOR_TYPES) + for actuator_type, actuator_defaults in ACTUATOR_DEFAULTS.items(): + default = etree.Element('default', **{'class': actuator_type}) + default.extend([ + etree.Element('joint', **actuator_defaults['joint']), + etree.Element('position', **actuator_defaults['position']) + ]) + defaults.append(default) + tree.getroot().insert(0, defaults) + + # Remove meshdir attribute from compiler tag + tree.find('.//compiler').attrib.pop('meshdir', None) + + # Add 'black' material to assets + tree.find('.//asset').append(etree.Element('material', name='black', rgba='0.2 0.2 0.2 1')) + + # Remove damping and frictionloss from all joints in the worldbody + for joint in tree.findall('.//worldbody/.//joint'): + # Remove the attributes + joint.attrib.pop('damping', None) + joint.attrib.pop('frictionloss', None) + joint.attrib.pop('limited', None) + # Add class based on the actuator type + joint.attrib['class'] = ACTUATOR_MAPPING[joint.attrib['name']] + + # Add actuators to the top level + actuator = etree.Element('actuator') + for joint in tree.findall('.//worldbody/.//joint'): + actuator.append(etree.Element('position', **{'joint': joint.attrib['name'], 'name': joint.attrib['name'], 'class': ACTUATOR_MAPPING[joint.attrib['name']] })) + tree.getroot().append(actuator) + + # Save the XML file with pretty formatting + output_filename = os.path.splitext(INPUT_URDF)[0] + '.xml' + tree.write(output_filename, pretty_print=True) + print('Saved MuJoCo model to', output_filename) + +if __name__ == '__main__': + main() diff --git a/bitbots_wolfgang/wolfgang_description/urdf/robot.xml b/bitbots_wolfgang/wolfgang_description/urdf/robot.xml new file mode 100644 index 000000000..398a3391e --- /dev/null +++ b/bitbots_wolfgang/wolfgang_description/urdf/robot.xml @@ -0,0 +1,612 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/bitbots_wolfgang/wolfgang_description/urdf/scene.xml b/bitbots_wolfgang/wolfgang_description/urdf/scene.xml new file mode 100644 index 000000000..7f77204d9 --- /dev/null +++ b/bitbots_wolfgang/wolfgang_description/urdf/scene.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/mujoco_env.py b/mujoco_env.py new file mode 100644 index 000000000..1461e5c15 --- /dev/null +++ b/mujoco_env.py @@ -0,0 +1,103 @@ +# This defines an openai gym environment to walk with the wolfgang robot in mujoco. + +import math +import mujoco +import gymnasium as gym +import cv2 +from tqdm import tqdm +import numpy as np + +import uuid + + +class WolfgangMujocoEnv(gym.Env): + def __init__(self, render_mode = "human"): + # Load the model + self.model = mujoco.MjModel.from_xml_path("/home/florian/Projekt/bitbots/bitbots_main/bitbots_wolfgang/wolfgang_description/urdf/scene.xml") + self.data = mujoco.MjData(self.model) + self.renderer = mujoco.Renderer(self.model, 415, 256) + + self.render_mode = render_mode + + # Extract joint names from the model (they are non unique!) + # We do this by reading a null terminated string at the address of the joint name in the joint names buffer + self.joint_ids_to_names = [self.model.names[self.model.name_jntadr[i]:].decode('utf-8').split('\0')[0] for i in range(self.model.njnt)] + + # Get which values of the observation belong to which joint + self.joint_id_observation_start = self.model.jnt_qposadr + + # Get the actuator names (they are also not unique) + self.actuator_ids_to_names = [self.model.names[self.model.name_actuatoradr[i]:].decode('utf-8').split('\0')[0] for i in range(self.model.nu)] + + + # Define the action space + self.action_space = gym.spaces.Box(low=-math.pi, high=math.pi, shape=(20,)) + # Define the observation space + self.observation_space = gym.spaces.Box(low=-math.pi, high=math.pi, shape=(20,)) + + self.max_length = 2 + + self.name = "WolfgangMujocoEnv-" + str(uuid.uuid4()) + + def reset(self, seed=None): + super().reset(seed=seed) + + # Reset the simulation + mujoco.mj_resetData(self.model, self.data) + return self._get_obs(), {} + + def _get_obs(self): + return self.data.qpos[self.joint_id_observation_start[1]:] + + def step(self, action): + + # Apply the action + for i in range(len(action)): + self.data.ctrl[i] = 0.2 * action[i] + + # Perform an action + for _ in range(10): + mujoco.mj_step(self.model, self.data) + + # Debugging + # Render + self.renderer.update_scene(self.data) + img = self.renderer.render() + cv2.imshow(self.name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + # Calculate the reward + reward = self.data.xpos[1][0] + + return self._get_obs(), reward, self.data.time >= self.max_length, False, {} + + def render(self): + # Render the simulation + self.renderer.update_scene(self.data) + img = self.renderer.render() + if self.render_mode == "human": + cv2.imshow("Wolfgang", cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + return img + + def close(self): + super().close() + # Explicitly delete the objects to free memory and get a more deterministic behavior + del self.data + del self.renderer + del self.model + + +if __name__ == "__main__": + env = WolfgangMujocoEnv() + env.reset() + last_frame = -1000 + for _ in tqdm(range(1000)): + print(env.step(env.action_space.sample())) + time = env.data.time + if time - last_frame > 1/60: + last_frame = time + cv2.imshow("Wolfgang", cv2.cvtColor(env.render(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + cv2.destroyAllWindows() + env.close() diff --git a/sb3_mujoco.py b/sb3_mujoco.py new file mode 100644 index 000000000..c1f236860 --- /dev/null +++ b/sb3_mujoco.py @@ -0,0 +1,28 @@ +from mujoco_env import WolfgangMujocoEnv + +import numpy as np + +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import SubprocVecEnv + +from stable_baselines3 import PPO + + +display_env = make_vec_env(WolfgangMujocoEnv, n_envs=1) + +model = PPO("MlpPolicy", display_env, verbose=1, n_steps=512) + +model.learn(20000, log_interval=1) + +display_env = model.get_env() +obs = display_env.reset() + +for i in range(1000): + action, _state = model.predict(obs, deterministic=True) + obs, reward, done, info = display_env.step(action) + if done: + obs = display_env.reset() + if i % 10 == 0: + display_env.render("human") + +display_env.close()