train_agent.py

这段代码是一个用于训练强化学习(RL)代理的脚本,专门针对名为 GreenLightHeatCO2 的环境。它使用 Stable-Baselines3 库中的 PPO(Proximal Policy Optimization)算法进行训练,并将训练结果记录到 Weights and Biases(wandb)平台。以下是代码的详细解释:


1. 脚本概述

  • 目的:训练一个 RL 代理,并在训练过程中定期评估代理的性能。训练结果会被记录到 wandb,并且最佳模型会被保存。

  • 功能

    • 创建训练和评估环境。

    • 初始化 PPO 模型。

    • 训练模型并定期评估。

    • 保存最佳模型和环境状态。

    • 支持从已保存的模型继续训练。


2. 主要模块和函数

(1) runExperiment 函数

这是脚本的核心函数,负责以下任务:

  • 初始化 wandb:用于记录训练过程中的指标和超参数。

  • 创建训练和评估环境

    • 使用 make_vec_env 函数创建并行化的训练环境和评估环境。

    • 环境会被封装在 VecNormalizeVecMonitor 中,以支持状态归一化和监控。

  • 初始化 PPO 模型

    • 如果 continue_trainingTrue,则加载已保存的模型和环境继续训练。

    • 否则,从头开始初始化 PPO 模型。

  • 训练模型

    • 使用 model.learn 方法训练模型。

    • 在训练过程中,定期调用回调函数进行评估,并保存最佳模型。

  • 保存模型和环境

    • 训练结束后,保存最终的模型和环境状态。

  • 清理资源

    • 关闭环境并释放内存。

(2) make_vec_env 函数

  • 用于创建并行化的环境。

  • 支持训练环境和评估环境的创建。

  • 环境会被封装在 VecNormalizeVecMonitor 中,以支持状态归一化和监控。

(3) create_callbacks 函数

  • 创建训练过程中的回调函数。

  • 主要用于定期评估模型性能,并保存最佳模型和环境状态。

(4) wandb_init 函数

  • 初始化 wandb 运行,用于记录训练过程中的指标和超参数。

(5) load_env_paramsload_model_params 函数

  • 从配置文件中加载环境和模型的超参数。


3. 命令行参数

脚本支持通过命令行参数自定义训练过程。以下是一些重要的参数:

  • --env_id:指定环境 ID(例如 GreenLightHeatCO2)。

  • --project:wandb 项目名称。

  • --group:wandb 组名称。

  • --total_timesteps:训练的总时间步数。

  • --n_eval_episodes:每次评估时运行的 episode 数量。

  • --num_cpus:用于训练的并行环境数量。

  • --n_evals:训练过程中评估的次数。

  • --continue_training:是否从已保存的模型继续训练。

  • --continued_project--continued_runname:指定要加载的模型所属的 wandb 项目和运行名称。


4. 代码流程

  1. 解析命令行参数

    • 使用 argparse 解析用户输入的命令行参数。

  2. 加载环境和模型参数

    • 从配置文件中加载环境和模型的超参数。

  3. 调用 runExperiment 函数

    • 初始化 wandb。

    • 创建训练和评估环境。

    • 初始化或加载 PPO 模型。

    • 训练模型并定期评估。

    • 保存最佳模型和环境状态。

  4. 清理资源

    • 关闭环境并释放内存。


5. 关键点

  • 并行化训练:通过 SubprocVecEnvnum_cpus 参数支持多进程并行训练。

  • 状态归一化:使用 VecNormalize 对观察值和奖励进行归一化。

  • 定期评估:通过回调函数定期评估模型性能,并保存最佳模型。

  • wandb 集成:训练过程中的指标和超参数会被记录到 wandb,方便跟踪和可视化。

  • 继续训练:支持从已保存的模型继续训练。


6. 示例用法

以下是一个示例命令行,用于运行脚本:

bash
复制
python -m greenlight.experiments.train_agent \
    --env_id GreenLightHeatCO2 \
    --project my_project \
    --group my_group \
    --total_timesteps 1000000 \
    --n_eval_episodes 5 \
    --num_cpus 8 \
    --n_evals 20 \
    --save_model \
    --save_env

7. 总结

这段代码是一个完整的 RL 训练脚本,适用于 GreenLightHeatCO2 环境。它结合了 Stable-Baselines3 和 wandb,提供了高效的训练、评估和模型保存功能。通过命令行参数,用户可以灵活地配置训练过程。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇