results.py

这段代码定义了一个名为 Results 的类,用于存储、更新和保存实验数据。以下是代码的详细解释:


1. 导入模块

  • pandas: 用于处理表格数据(DataFrame)。

  • numpy: 用于数值计算和数组操作。


2. Results

这是一个用于管理实验结果的类,支持数据的更新和保存。

构造函数 (__init__)

python
复制
def __init__(self, col_names):
    self.col_names = col_names
    self.df = pd.DataFrame()
  • 参数:

    • col_names: 列名的列表,用于定义结果数据的列。

  • 功能:

    • 初始化一个空的 DataFrame (self.df),列名由 col_names 指定。


update_result 方法

python
复制
def update_result(self, data):
    assert data.shape[-1] + 1 == len(self.col_names),\
        f"The shape of the input array doesn't match the number of columns in the results dataframe. {data.shape[-1]+1} columns were expected vs {len(self.col_names)}."
    
    self.df= pd.DataFrame(columns=self.col_names)

    for episode in range(data.shape[0]):
        # add the episode number to the data
        episode_data = np.concatenate((data[episode], np.full(shape=(data.shape[1], 1), fill_value=episode)), axis=1)
        self.df = self.df._append(pd.DataFrame(data=episode_data, columns=self.col_names), ignore_index=True)
  • 参数:

    • data: 一个多维数组(通常是 3D 数组),包含多个 episode 的实验数据。

  • 功能:

    1. 检查输入数据的形状:

      • 使用 assert 检查输入数据的最后一维大小是否与列名数量匹配(data.shape[-1] + 1 == len(self.col_names))。

      • 如果不匹配,抛出错误信息。

    2. 清空 DataFrame:

      • 每次调用 update_result 时,清空现有的 DataFrame,并重新初始化列名。

    3. 更新数据:

      • 遍历每个 episode 的数据。

      • 将 episode 编号添加到数据中(np.full(shape=(data.shape[1], 1), fill_value=episode))。

      • 将 episode 数据转换为 DataFrame,并追加到 self.df 中。


save 方法

python
复制
def save(self, filename):
    self.df.to_csv(filename, index=False)
  • 参数:

    • filename: 保存结果的文件名(包括路径)。

  • 功能:

    • self.df 保存为 CSV 文件,不包含行索引。


3. 代码的核心功能

  • 数据管理:

    • 支持将多维实验数据(如多个 episode 的观测值、动作、奖励等)存储到 DataFrame 中。

    • 每次更新数据时,清空现有数据并重新填充。

  • 数据保存:

    • 将结果保存为 CSV 文件,方便后续分析和可视化。


4. 代码的应用场景

  • 强化学习实验:

    • 存储每个 episode 的观测值、动作、奖励等数据。

    • 保存实验结果,用于后续分析。

  • 数据记录:

    • 记录实验过程中的关键指标(如损失、准确率等)。


5. 示例

假设我们有以下实验数据:

python
复制
data = np.array([
    [[1, 2, 3], [4, 5, 6]],  # Episode 1
    [[7, 8, 9], [10, 11, 12]]  # Episode 2
])
col_names = ["obs1", "obs2", "obs3", "episode"]

使用 Results 类:

python
复制
# 初始化 Results 对象
results = Results(col_names)

# 更新数据
results.update_result(data)

# 保存结果
results.save("results.csv")

results.csv 文件内容:

复制
obs1,obs2,obs3,episode
1.0,2.0,3.0,0.0
4.0,5.0,6.0,0.0
7.0,8.0,9.0,1.0
10.0,11.0,12.0,1.0

6. 代码的改进建议

  1. 不清空 DataFrame:

    • 当前实现中,每次调用 update_result 都会清空 DataFrame。如果希望累积数据,可以移除 self.df = pd.DataFrame(columns=self.col_names) 这一行。

  2. 性能优化:

    • 使用 pd.concat 代替 _append,因为 _append 已被弃用,且 pd.concat 性能更好。

  3. 异常处理:

    • 添加对输入数据形状的异常处理,避免程序崩溃。


总结

这段代码实现了一个简单的 Results 类,用于存储、更新和保存实验数据。它适用于强化学习实验中的数据记录和管理,但需要注意每次更新数据时会清空现有数据。可以通过改进实现来支持数据累积和性能优化。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇