JORDAN CAMPBELL
R&D SOFTWARE ENGINEER
cybernaut/0.1.4

A basic reinforcement learning setup using Gymnasium and Stable Baselines3.

I'm doing all this work on an M1 Mac (macOS), so some steps may need adapting for other machines or operating systems.

Start by creating a new directory and a virtual environment.

mkdir rl
cd rl

python3 -m venv rl-env
source rl-env/bin/activate

brew install cmake openmpi  # This is recommended by the Gymnasium team

pip install stable-baselines3[extra] gymnasium[box2d]

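If you're using zsh as your shell (the default on macOS), you'll need to escape the brackets, e.g. stable-baselines3\[extra\] and gymnasium\[box2d\], or wrap each package in quotes: pip install "stable-baselines3[extra]" "gymnasium[box2d]"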

You may also get errors about needing to install swig, which is used to build the Box2D bindings that LunarLander depends on. I found that running brew install swig fixed this, whereas pip install swig didn't work.
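Before going further, it can help to confirm the steps above actually took effect. The sketch below is a hypothetical helper (the file name check_setup.py is just a suggestion). It checks that the interpreter is running inside a venv, reports which packages can be imported, and shows whether the swig executable is on your PATH. It assumes that gymnasium[box2d] provides the "Box2D" module and that stable-baselines3[extra] pulls in OpenCV as "cv2"; adjust the list if your install differs.

```python
# check_setup.py (hypothetical file name): sanity-check the install steps
import importlib.util
import shutil
import sys

# Import names, not pip names. Assumed mapping: gymnasium[box2d] provides the
# "Box2D" module, and stable-baselines3[extra] pulls in OpenCV as "cv2".
MODULES = ["gymnasium", "stable_baselines3", "Box2D", "cv2"]


def in_virtualenv():
    """True inside a venv, where sys.prefix differs from sys.base_prefix."""
    return sys.prefix != sys.base_prefix


def check_modules(names):
    """Map each top-level module name to whether it can be found (not imported)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def swig_path():
    """Path of the swig executable (only needed to build Box2D), or None."""
    return shutil.which("swig")


if __name__ == "__main__":
    print("inside a virtual environment:", in_virtualenv())
    for name, found in check_modules(MODULES).items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
    print("swig:", swig_path() or "not found (try: brew install swig)")
```

Run it with the venv activated (python3 check_setup.py). A MISSING line for Box2D usually points back to the swig issue.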

A good starting point is the LunarLander environment. It's complex enough to be interesting, but also basically trivial to train and get working within a minute or so. The code here is taken from the Gymnasium documentation, but I've added command line flags to switch between training a new model and loading a saved one. This makes it a bit easier to run the script multiple times and play around while you're getting started.
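Checking sys.argv directly, as main.py does, is the simplest way to handle the -train and -verbose flags. If you'd rather have mistyped flags reported and a -h help message for free, the standard library's argparse module is an alternative. The sketch below swaps in argparse and is not what main.py itself does. It keeps the same single-dash flag names and returns a dict with the same keys as main.py's config. The extra -steps flag is a hypothetical addition for setting training_steps from the command line.

```python
# A possible argparse version of main.py's flag handling (an alternative
# approach, not the original). "-steps" is a hypothetical extra flag.
import argparse


def parse_config(argv=None):
    """Parse command line flags; argv=None means use sys.argv[1:]."""
    parser = argparse.ArgumentParser(
        description="Train or load a PPO agent on LunarLander")
    parser.add_argument("-train", dest="training_mode", action="store_true",
                        help="train a new model instead of loading ppo_lunar")
    parser.add_argument("-verbose", action="store_true",
                        help="print Stable Baselines3 training logs")
    parser.add_argument("-steps", type=int, default=1000,
                        help="total training timesteps (hypothetical flag)")
    args = parser.parse_args(argv)
    # Same keys as the config dict in main.py, plus the step count.
    return {
        "training_mode": args.training_mode,
        "verbose": args.verbose,
        "training_steps": args.steps,
    }


if __name__ == "__main__":
    print(parse_config())
```

With this, python3 main.py -train -steps 100000 would train for longer, and an unrecognised flag makes argparse exit with an error instead of being silently ignored.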

# main.py

import gymnasium as gym
import sys
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# e.g. 'python3 main.py -train -verbose'
cfg = {
    'training_mode': '-train' in sys.argv,
    'verbose': '-verbose' in sys.argv,
}

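# Start with 1000 and increase it as you wish. PPO collects experience in
# rollouts of n_steps=2048 by default, so values below that still run 2048 steps.
training_steps = 1000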

env = gym.make("LunarLander-v3", render_mode="rgb_array")

if cfg['training_mode']:
    model = PPO("MlpPolicy", env, verbose=cfg['verbose'])
    model.learn(total_timesteps=int(training_steps), progress_bar=True)
    model.save("ppo_lunar")
else:
    model = PPO.load("ppo_lunar", env=env)

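# Average total reward over 10 episodes (LunarLander counts 200+ as solved).
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")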

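# Watch the agent play. The VecEnv returned by get_env() resets itself when an
# episode ends, so the loop can simply keep stepping.
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    # The env was created with render_mode="rgb_array"; asking the VecEnv for
    # "human" mode displays those frames in an OpenCV window.
    vec_env.render("human")

vec_env.close()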
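evaluate_policy reports the mean and standard deviation of the total reward per episode. The loop it hides looks roughly like the sketch below, written against the raw Gymnasium API. There, reset() returns (obs, info) and step() returns (obs, reward, terminated, truncated, info). The VecEnv from model.get_env() is different: its reset() returns only obs and its step() returns a combined dones array. So that the sketch runs without Box2D, it uses a hypothetical toy environment, CountdownEnv, and a hypothetical policy function. With the real thing you would pass gym.make("LunarLander-v3") and a policy built from model.predict.

```python
# A stdlib-only sketch of an evaluation loop using the Gymnasium step API.
# CountdownEnv and always_one are hypothetical stand-ins for a real env/policy.
import statistics


class CountdownEnv:
    """Toy env: reward 1.0 per step when action == 1; ends after `length` steps."""

    def __init__(self, length=5, max_episode_steps=100):
        self.length = length
        self.max_episode_steps = max_episode_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return self.t, {}  # Gymnasium-style: (observation, info)

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        terminated = self.t >= self.length            # the task itself ended
        truncated = self.t >= self.max_episode_steps  # time limit reached
        return self.t, reward, terminated, truncated, {}


def evaluate(policy, env, n_eval_episodes=10):
    """Return (mean, std) of total reward per episode."""
    episode_returns = []
    for _ in range(n_eval_episodes):
        obs, _info = env.reset()
        total, done = 0.0, False
        while not done:
            obs, reward, terminated, truncated, _info = env.step(policy(obs))
            total += reward
            done = terminated or truncated  # either flag ends the episode
        episode_returns.append(total)
    # pstdev is the population std, matching numpy's default np.std.
    return statistics.mean(episode_returns), statistics.pstdev(episode_returns)


def always_one(obs):
    return 1  # hypothetical policy that always picks action 1


if __name__ == "__main__":
    print(evaluate(always_one, CountdownEnv(length=5)))  # → (5.0, 0.0)
```

Treating terminated and truncated separately matters when you write your own loops: an episode cut off by a time limit (truncated) ends just the same, but it didn't end because the task was finished.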