这段代码定义了一个名为 Results
的类,用于存储、更新和保存实验数据。以下是代码的详细解释:
1. 导入模块
-
pandas
: 用于处理表格数据(DataFrame)。 -
numpy
: 用于数值计算和数组操作。
2. Results
类
这是一个用于管理实验结果的类,支持数据的更新和保存。
构造函数 (__init__
)
def __init__(self, col_names): self.col_names = col_names self.df = pd.DataFrame()
-
参数:
-
col_names
: 列名的列表,用于定义结果数据的列。
-
-
功能:
-
初始化一个空的 DataFrame (
self.df
),列名由col_names
指定。
-
update_result
方法
def update_result(self, data): assert data.shape[-1] + 1 == len(self.col_names),\ f"The shape of the input array doesn't match the number of columns in the results dataframe. {data.shape[-1]+1} columns were expected vs {len(self.col_names)}." self.df= pd.DataFrame(columns=self.col_names) for episode in range(data.shape[0]): # add the episode number to the data episode_data = np.concatenate((data[episode], np.full(shape=(data.shape[1], 1), fill_value=episode)), axis=1) self.df = self.df._append(pd.DataFrame(data=episode_data, columns=self.col_names), ignore_index=True)
-
参数:
-
data
: 一个多维数组(通常是 3D 数组),包含多个 episode 的实验数据。
-
-
功能:
-
检查输入数据的形状:
-
使用
assert
检查输入数据的最后一维大小是否与列名数量匹配(data.shape[-1] + 1 == len(self.col_names)
)。 -
如果不匹配,抛出错误信息。
-
-
清空 DataFrame:
-
每次调用
update_result
时,清空现有的 DataFrame,并重新初始化列名。
-
-
更新数据:
-
遍历每个 episode 的数据。
-
将 episode 编号添加到数据中(
np.full(shape=(data.shape[1], 1), fill_value=episode)
)。 -
将 episode 数据转换为 DataFrame,并追加到
self.df
中。
-
-
save
方法
def save(self, filename): self.df.to_csv(filename, index=False)
-
参数:
-
filename
: 保存结果的文件名(包括路径)。
-
-
功能:
-
将
self.df
保存为 CSV 文件,不包含行索引。
-
3. 代码的核心功能
-
数据管理:
-
支持将多维实验数据(如多个 episode 的观测值、动作、奖励等)存储到 DataFrame 中。
-
每次更新数据时,清空现有数据并重新填充。
-
-
数据保存:
-
将结果保存为 CSV 文件,方便后续分析和可视化。
-
4. 代码的应用场景
-
强化学习实验:
-
存储每个 episode 的观测值、动作、奖励等数据。
-
保存实验结果,用于后续分析。
-
-
数据记录:
-
记录实验过程中的关键指标(如损失、准确率等)。
-
5. 示例
假设我们有以下实验数据:
data = np.array([ [[1, 2, 3], [4, 5, 6]], # Episode 1 [[7, 8, 9], [10, 11, 12]] # Episode 2 ]) col_names = ["obs1", "obs2", "obs3", "episode"]
使用 Results
类:
# 初始化 Results 对象 results = Results(col_names) # 更新数据 results.update_result(data) # 保存结果 results.save("results.csv")
results.csv
文件内容:
obs1,obs2,obs3,episode 1.0,2.0,3.0,0.0 4.0,5.0,6.0,0.0 7.0,8.0,9.0,1.0 10.0,11.0,12.0,1.0
6. 代码的改进建议
-
不清空 DataFrame:
-
当前实现中,每次调用
update_result
都会清空 DataFrame。如果希望累积数据,可以移除self.df = pd.DataFrame(columns=self.col_names)
这一行。
-
-
性能优化:
-
使用
pd.concat
代替_append
,因为_append
已被弃用,且pd.concat
性能更好。
-
-
异常处理:
-
添加对输入数据形状的异常处理,避免程序崩溃。
-
总结
这段代码实现了一个简单的 Results
类,用于存储、更新和保存实验数据。它适用于强化学习实验中的数据记录和管理,但需要注意每次更新数据时会清空现有数据。可以通过改进实现来支持数据累积和性能优化。