欢迎您访问程序员文章站!本站旨在为大家提供和分享程序员计算机编程知识。
您现在的位置是: 首页

论文实现:处理数据之生成batch

程序员文章站 2022-05-01 12:42:52
...

生成batch代码:

def generate_batch(self, batch_size):
    """Build the per-batch index arrays for one pass over the dataset.

    If ``self.shuffle`` is truthy, ``self.inputs``, ``self.mask`` and
    ``self.targets`` are first re-ordered in place with one shared random
    permutation, so the three arrays stay aligned with each other.
    (Assumes these attributes support NumPy integer-array indexing --
    confirm against the class constructor.)

    The index range ``[0, self.length)`` is then cut into consecutive
    batches of ``batch_size`` indices. When ``self.length`` is not a
    multiple of ``batch_size``, the last batch is rebuilt from the final
    ``batch_size`` indices, so it overlaps the previous batch instead of
    running past the end of the data.

    Args:
        batch_size: Number of indices per batch; must be positive.

    Returns:
        A list of 1-D integer index arrays, one per batch. Empty when the
        dataset is empty. When the dataset holds fewer than ``batch_size``
        items, the single batch contains every index once (and is
        therefore shorter than ``batch_size``).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if self.length == 0:
        # np.split cannot produce zero sections; there is nothing to batch.
        return []

    if self.shuffle:
        # One permutation applied to all three arrays keeps rows aligned.
        shuffled_arg = np.arange(self.length)
        np.random.shuffle(shuffled_arg)
        self.inputs = self.inputs[shuffled_arg]
        self.mask = self.mask[shuffled_arg]
        self.targets = self.targets[shuffled_arg]

    # Ceiling division in integer arithmetic: a partial trailing batch
    # still counts as a batch.
    n_batch = (self.length + batch_size - 1) // batch_size

    # Split the padded index range into n_batch equal pieces.
    slices = np.split(np.arange(n_batch * batch_size), n_batch)

    # The padded last piece may point past the data, so replace it with the
    # final batch_size valid indices. Clamp at 0 so a dataset smaller than
    # one batch never yields negative indices.
    last_start = max(self.length - batch_size, 0)
    slices[-1] = np.arange(last_start, self.length)
    return slices

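最后两个batch的输出如下(本例中数据总长度为1205,batch_size为100):
最后一个batch由 np.arange(self.length - batch_size, self.length) 重新生成,因此它与上一个batch有大量重复的索引(1105~1199)。这样做是为了防止最后一组索引越界。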

 array([1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110,
       1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121,
       1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132,
       1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143,
       1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154,
       1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165,
       1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176,
       1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187,
       1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198,
       1199]), 
       array([1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115,
       1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126,
       1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137,
       1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148,
       1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159,
       1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170,
       1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181,
       1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192,
       1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203,
       1204])]

相关标签: batch