python lmdb使用

it2025-11-09  13

python lmdb使用

python lmdb使用:安装、项目实际用例、pytorch dataloader方法重写

python lmdb使用

LMDB和SQLite/MySQL等关系型数据库不同,属于key-value数据库(把LMDB想成dict会比较容易理解)。注意在Python 3中,键key与值value都必须是字节串(bytes),写入字符串前需要先encode()。

安装

pip install lmdb

### 操作流程

1. 创建lmdb环境 `env = lmdb.open()`
2. 建立事务 `txn = env.begin()`
3. 向事务中写入或者修改数据 `txn.put(key, value)`
4. 删除数据 `txn.delete(key)`
5. 数据查询 `txn.get(key)`
6. 数据遍历 `txn.cursor()`
7. 数据提交 `txn.commit()`

项目实际用例

在进行OCR文本识别的过程中训练的数据量较大,所以采用将数据保存为LMDB数据,指定图片的主路径与标签文件

如下所示:

# coding:utf-8
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import re
from PIL import Image
import numpy as np
import imghdr
import argparse
from tqdm import tqdm


def init_args():
    """Parse command-line arguments for building the LMDB dataset."""
    args = argparse.ArgumentParser()
    args.add_argument('-i', '--image_dir', default='', type=str,
                      help='The directory of the dataset , which contains the images')
    args.add_argument('-l', '--label_file',
                      default='/datassd/hzl/text_render_data/mingpian/new_gray.txt',
                      type=str,
                      help='The file which contains the paths and the labels of the data set')
    args.add_argument('-s', '--save_dir',
                      default='/datassd/hzl/text_render_data/mingpian/lmdb_gray/',
                      type=str,
                      help='The generated mdb file save dir')
    args.add_argument('-m', '--map_size', help='map size of lmdb', type=int,
                      default=40000000000000)
    return args.parse_args()


def checkImageIsValid(imageBin):
    """Return True if *imageBin* decodes to a non-empty grayscale image.

    Args:
        imageBin: raw image file bytes (or None).
    """
    if imageBin is None:
        return False
    try:
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        # imdecode returns None (instead of raising) on undecodable data;
        # the original code let that surface as an AttributeError swallowed
        # by a bare except — check it explicitly.
        if img is None:
            return False
        imgH, imgW = img.shape[0], img.shape[1]
    except Exception:  # narrow from bare except: only decode failures
        return False
    return imgH * imgW != 0


def writeCache(env, cache):
    """Write all cached key/value pairs to *env* in one write transaction.

    Keys and values that are str are encoded to bytes, as LMDB requires.
    """
    with env.begin(write=True) as txn:  # open a write transaction
        for k, v in cache.items():
            if isinstance(k, str):
                k = k.encode()
            if isinstance(v, str):
                v = v.encode()
            txn.put(k, v)  # write the record


def createDataset(outputPath, imagePathList, labelList, map_size,
                  lexiconList=None, checkValid=True):
    """Create LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        map_size      : maximum size the LMDB map may grow to, in bytes
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=map_size)  # create the lmdb environment
    cache = {}
    cnt = 0
    for i in tqdm(range(nSamples)):
        imagePath = imagePathList[i].replace('\n', '').replace('\r\n', '')
        label = labelList[i]
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue  # skip without consuming an index: keys stay dense
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        # Flush every 1000 samples so the in-memory cache stays small.
        if cnt != 0 and cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    # Record the total count under the conventional 'num-samples' key and
    # flush whatever is left in the cache.
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    env.close()  # close the environment
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    args = init_args()  # parse CLI arguments
    # Use a context manager so the label file is always closed
    # (the original left the handle open).
    with open(args.label_file, mode='r') as imgdata:
        lines = list(imgdata)  # one label entry per line
    imgDir = args.image_dir
    imgPathList = []
    labelList = []
    # Read the label file, drop malformed / missing entries, and build
    # parallel lists of image paths and labels.
    for i, line in enumerate(lines):
        if line.strip() == '':
            continue
        # NOTE(review): lines containing a space are split on TAB (so labels
        # may themselves contain spaces); otherwise split on any whitespace.
        # Preserved as-is — confirm this matches the label-file format.
        if ' ' in line:
            imgPath, word = line.strip('\n').strip().split('\t')
        else:
            imgPath, word = line.strip('\n').strip().split()
        if not os.path.exists(imgPath):
            continue
        imgPathList.append(imgPath)
        labelList.append(word)
    # Write everything into the LMDB database.
    createDataset(args.save_dir, imgPathList, labelList, args.map_size)

pytorch dataloader方法重写

import lmdb
import six
import sys
from PIL import Image
import cv2
import numpy as np
from lib.dataset.transformers import *


class LMDB(Dataset):
    """PyTorch Dataset backed by an LMDB database of OCR image/label pairs.

    Keys follow the layout produced by the dataset-creation script:
    'image-%09d' / 'label-%09d' (1-based) plus 'num-samples'.
    """

    def __init__(self, root=None, is_train=True):
        # Select the train or test LMDB file from the project config.
        if is_train:
            root = config.DATASET.TRAIN_FILE
        else:
            root = config.DATASET.TEST_FILE
        # BUG FIX: the original then did `root = config.LMDB_ROOT`,
        # unconditionally clobbering the train/test selection above.
        # That override is removed so the selection takes effect.
        transform = config.DATASET.TRANSFORM  # BUG FIX: was misspelled
        # `trainsform`, making the later `self.transform = transform`
        # a NameError at runtime.
        target_transform = config.DATASET.TARGET_TRANSFORM
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            map_size=40000000000000)
        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(1)  # BUG FIX: exit non-zero on failure (was 0)
        with self.env.begin(write=False) as txn:
            # Total sample count stored under the reserved key.
            # (Renamed from `str`, which shadowed the builtin.)
            num_key = 'num-samples'.encode('utf-8')
            self.nSamples = int(txn.get(num_key))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples stored in the LMDB file."""
        return self.nSamples

    def __getitem__(self, index):
        """Return one (image, label) pair.

        Indices are 0-based from the caller's view; LMDB keys are 1-based,
        hence the `index += 1` shift.
        """
        assert index <= len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode('utf-8'))
            try:
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img = Image.open(buf)
                if self.target_transform is None:
                    # Round-trip PIL -> OpenCV for random augmentation,
                    # then back to PIL for the torchvision-style transform.
                    img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
                    img = random_transformers(img)  # random augmentation
                    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                img = np.array(img)
                img = Image.fromarray(np.uint8(img))
            except Exception:  # narrowed from bare except
                # Corrupt image: fall back to the first sample so the
                # DataLoader keeps running (behavior kept from original).
                return self[1]
            if self.transform is not None:
                img = self.transform(img)
            label_key = 'label-%09d' % index
            label = txn.get(label_key.encode())
            # Strip spaces and normalize the fullwidth yen sign.
            label = label.decode('utf-8').replace(' ', '').replace('¥', '¥').encode('utf-8')
            if self.target_transform is not None:
                label = self.target_transform(label)
        return img, label
最新回复(0)