DeepSTdeepstdatasetsSTMatrix.py 代码解析

    科技2024-11-06  22

    from __future__ import print_function import os import pandas as pd import numpy as np from . import load_stdata from ..config import Config from ..utils import string2timestamp //from . import,“.” 代表使用相对路径导入,即从当前项目中寻找需要导入的包或数,from..import绝对导入语句。一个"."表示往上跳一级,假如A包含B和C,要往B里import一个东西,可以写from ..A(两个".",跳的比A高一级了,可) import C. class STMatrix(object): """docstring for STMatrix"""//STMatrix的字符串文本 def __init__(self, data, timestamps, T=48, CheckComplete=True)://定义构造函数 super(STMatrix, self).__init__()//#super表继承,这里继承自己 assert len(data) == len(timestamps)//assert:断言 前置条件断言:代码执行之前必须具备的特性,如果不满足程序就会中断 self.data = data self.timestamps = timestamps self.T = T self.pd_timestamps = string2timestamp(timestamps, T=self.T)//字符转换成时间戳,timestamp = time.time(),为float型,时间戳是计算机能够识别的时间;时间字符串是人能够看懂的时间;元组则是用来操作时间的。 if CheckComplete: self.check_complete()//转成每半小时每半小时 # index self.make_index()//以每半小时的时间做索引 def make_index(self)://可能是做一个索引 self.get_index = dict()//字典 包括索引和对应的值 for i, ts in enumerate(self.pd_timestamps)://enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。 self.get_index[ts] = i def check_complete(self): missing_timestamps = [] offset = pd.DateOffset(minutes=24 * 60 // self.T) //T=48,"//"表示取整除 - 返回商的整数部分(向下取整)。DateOffset可按指定的日历日时间段偏移日期时间,可能是把时间归成每半个小时一个。 pd_timestamps = self.pd_timestamps i = 1 while i < len(pd_timestamps): if pd_timestamps[i-1] + offset != pd_timestamps[i]: missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i])) i += 1 for v in missing_timestamps: print(v) assert len(missing_timestamps) == 0 def get_matrix(self, timestamp)://获得timestamp为索引的data return self.data[self.get_index[timestamp]] def save(self, fname): pass //Python pass 是空语句,是为了保持程序结构的完整性。pass 不做任何事情,一般用做占位语句。该处的 pass 便是占据一个位置,因为如果定义一个空函数程序会报错,当你没有想好函数的内容是可以用 pass 填充,使程序可以正常运行。 def check_it(self, depends): //检查关键字,应该要是对的 for d in depends: if d not in self.get_index.keys(): return False return True def create_dataset(self, len_closeness=3, len_trend=3, TrendInterval=7, len_period=3, PeriodInterval=1): //这个函数没细看(太难了),大概讲的是构造出临近性、趋势性、周期性的三种数据形式 """current version """ # offset_week = pd.DateOffset(days=7) offset_frame = pd.DateOffset(minutes=24 * 60 // self.T) XC = [] XP = [] XT = [] Y = [] timestamps_Y = [] depends = [range(1, len_closeness+1), [PeriodInterval * self.T * j for j in range(1, len_period+1)], [TrendInterval * self.T * j for j in range(1, len_trend+1)]] i = max(self.T * TrendInterval * len_trend, self.T * PeriodInterval * len_period, len_closeness) while i < len(self.pd_timestamps): Flag = True for depend in depends: if Flag is False: break Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend]) if Flag is False: i += 1 continue x_c = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[0]] x_p = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[1]] x_t = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[2]] y = self.get_matrix(self.pd_timestamps[i]) if len_closeness > 0: XC.append(np.vstack(x_c)) if len_period > 0: XP.append(np.vstack(x_p)) if len_trend > 0: XT.append(np.vstack(x_t)) Y.append(y) timestamps_Y.append(self.timestamps[i]) i += 1 XC = np.asarray(XC) XP = np.asarray(XP) XT = np.asarray(XT) Y = np.asarray(Y) print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape) return XC, XP, XT, Y, timestamps_Y if __name__ == '__main__': pass //当**.py**文件被直接运行时,if __name__ ==’__main__'之下的代码块将被运行;当.py文件以模块形式被导入时,if __name__ == '__main__'之下的代码块不被运行。

    总代码链接https://github.com/amirkhango/DeepST

    Processed: 0.011, SQL: 8