博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Kreas中Sequence的使用样例
阅读量:4142 次
发布时间:2019-05-25

本文共 6683 字,大约阅读时间需要 22 分钟。

训练TF模型时发现电脑内存太小了,无法处理一万张720p的图片(VOC格式数据集),于是改用Sequence进行训练迭代,有效减少内存的要求。代码如下:

from tensorflow.python.keras.utils.data_utils import Sequence# 定义在C:\ProgramData\Anaconda3\envs\tf\Lib\site-packages\tensorflow_core\python\keras\utils\data_utils.pyimport random, os, gc, cv2import numpy as npfrom xml.dom.minidom import parsefrom sklearn.preprocessing import MultiLabelBinarizerfrom sklearn.model_selection import train_test_splitfrom tensorflow.keras.preprocessing.image import img_to_array, load_imgseed = 295random.seed(seed)class SequenceData(Sequence):    '''    xmlPath, imgPath, cutSize=(0,0.7), batch_size=32, size=(720,1280)    xml文件路径, img文件路径, 训练集切分起始比例, 批次大小, 图片大小    '''    def resize_img_keep_ratio(self, img_name,target_size):        '''        1.resize图片,先计算最长边的resize的比例,然后按照该比例resize。        2.计算四个边需要padding的像素宽度,然后padding        '''        img = cv2.imread(img_name)        old_size = img.shape[0:2]        ratio = min(float(target_size[i])/(old_size[i]) for i in range(len(old_size)))        new_size = tuple([int(i*ratio) for i in old_size])        img = cv2.resize(img,(new_size[1], new_size[0]),interpolation=cv2.INTER_CUBIC)  #注意插值算法        pad_w = target_size[1] - new_size[1]        pad_h = target_size[0] - new_size[0]        top,bottom = pad_h//2, pad_h-(pad_h//2)        left,right = pad_w//2, pad_w -(pad_w//2)        img_new = cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,None,(0,0,0))        return cv2.cvtColor(img_new, cv2.COLOR_BGR2RGB)    def resize_img(self, img_name,target_size):        img = cv2.imread(img_name)        img = cv2.resize(img,(target_size[1], target_size[0]),interpolation=cv2.INTER_CUBIC)        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    def __init__(self, xmlPath, imgPath, cutSize=(0,0.7), batch_size=32, size=(720,1280)):        self.xmlPath = xmlPath        self.imgPath = imgPath        self.batch_size = batch_size        self.cutSize = cutSize        self.datas = os.listdir(xmlPath) # 读入文件列表        random.shuffle(self.datas)        self.L = len(self.datas)        self.datas = self.datas[int(np.ceil(self.L*cutSize[0])):int(np.ceil(self.L*cutSize[1]))] # 切分训练集比例        self.L = len(self.datas)        self.size = size        self.index = random.sample(range(self.L), self.L)        self.names = ['asleep', 'side asleep', 'quilt kicked', 'awake',                        'on stomach', 'crying', 'face covered']  #7个类型    #返回长度,通过len(
<你的实例>
)调用 def __len__(self): return int(np.ceil(len(self.datas) / self.batch_size)) #即通过索引获取a[0],a[1]这种 def __getitem__(self, idx): batch_indexs = self.index[idx:(idx+self.batch_size)] batch_datas = [self.datas[k] for k in batch_indexs] images,labels = self.data_generation(batch_datas) return images,labels def data_generation(self, batch_datas): #预处理操作 data=[] # 分类数据 img=[] # 图片顺序 for file in batch_datas: DOMTree = parse(os.path.join(self.xmlPath,file)) #读取XML文件 imgName = file.replace('xml','jpg') #标签有问题,直接读取文件名 if os.path.exists(os.path.join(self.imgPath,imgName)): #如果图片文件存在,读入并压缩图片 img.append(self.resize_img_keep_ratio(os.path.join(self.imgPath,imgName),(self.size[0],self.size[1]))) name = [] # for obj in DOMTree.documentElement.getElementsByTagName("object"): # 对于每个object标签(在带有多个标签的数据集上会导致loss爆炸) # name.append(obj.getElementsByTagName("name")[0].childNodes[0].data) name.append(DOMTree.documentElement.getElementsByTagName("object")[0].getElementsByTagName("name")[0].childNodes[0].data) # 导入第一个标签 data.append(name) else: img.append(np.zeros((self.size[0],self.size[1],3)).tolist()) data.append([]) return np.asarray(img)/255, np.asarray(MultiLabelBinarizer(classes=self.names).fit_transform(data)) # 多分类独热编码 def get_label(self): #预处理操作 data=[] # 分类数据 for file in [self.datas[k] for k in self.index]: DOMTree = parse(os.path.join(self.xmlPath,file)) #读取XML文件 imgName = file.replace('xml','jpg') #标签有问题,直接读取文件名 if os.path.exists(os.path.join(self.imgPath,imgName)): #如果图片文件存在,读入并压缩图片 name = [] # for obj in DOMTree.documentElement.getElementsByTagName("object"): # 对于每个object标签(在带有多个标签的数据集上会导致loss爆炸) # name.append(obj.getElementsByTagName("name")[0].childNodes[0].data) name.append(DOMTree.documentElement.getElementsByTagName("object")[0].getElementsByTagName("name")[0].childNodes[0].data) # 导入第一个标签 data.append(name) else: data.append([]) return np.asarray(MultiLabelBinarizer(classes=self.names).fit_transform(data)) # 多分类独热编码 def showImg(self,i): from tensorflow.keras.preprocessing.image import array_to_img images,labels = self.data_generation([self.datas[i]]) for ii,xx in enumerate(labels[0]): if xx > 0: print(self.names[ii],end=',') print() array_to_img(images[0]*255.0).show()
使用方法config = {
"batch":2, "epochs":10, "imageResize":(720,1280), "lr":1e-5, "cut_size":(0,0.7,0.85,1),}def trainModelBySequence(xmlPath, imgPath): import DataGenSequence DGS_train = DataGenSequence.SequenceData(xmlPath, imgPath, cutSize=(config["cut_size"][0],config["cut_size"][1]), batch_size=config["batch"], size=config['imageResize']) DGS_val = DataGenSequence.SequenceData(xmlPath, imgPath, cutSize=(config["cut_size"][1],config["cut_size"][2]), batch_size=config["batch"], size=config['imageResize']) DGS_test = DataGenSequence.SequenceData(xmlPath, imgPath, cutSize=(config["cut_size"][2], config["cut_size"][3]), batch_size=config["batch"], size=config['imageResize']) from tensorflow.keras.callbacks import EarlyStopping early_stop = EarlyStopping(monitor='val_loss',patience=config["epochs"]/10,verbose=1,mode='auto') hist = model.fit_generator(generator=DGS_train,steps_per_epoch=int(len(DGS_train)), validation_data=DGS_val,validation_steps=int(len(DGS_val)), workers=20,use_multiprocessing=False,verbose=1, epochs=config["epochs"],callbacks=[early_stop,metrics]) return hist

转载地址:http://crzti.baihongyu.com/

你可能感兴趣的文章
python 变量作用域问题(经典坑)
查看>>
pytorch
查看>>
pytorch(三)
查看>>
ubuntu相关
查看>>
C++ 调用json
查看>>
nano中设置脚本开机自启动
查看>>
动态库调动态库
查看>>
Kubernetes集群搭建之CNI-Flanneld部署篇
查看>>
k8s web终端连接工具
查看>>
手绘VS码绘(一):静态图绘制(码绘使用P5.js)
查看>>
手绘VS码绘(二):动态图绘制(码绘使用Processing)
查看>>
基于P5.js的“绘画系统”
查看>>
《达芬奇的人生密码》观后感
查看>>
论文翻译:《一个包容性设计的具体例子:聋人导向可访问性》
查看>>
基于“分形”编写的交互应用
查看>>
《融入动画技术的交互应用》主题博文推荐
查看>>
链睿和家乐福合作推出下一代零售业隐私保护技术
查看>>
Unifrax宣布新建SiFAB™生产线
查看>>
艾默生纪念谷轮™在空调和制冷领域的百年创新成就
查看>>
NEXO代币持有者获得20,428,359.89美元股息
查看>>