深度學(xué)習(xí)進階筆記之八 TensorFlow與中文手寫漢字識別.docx_第1頁
深度學(xué)習(xí)進階筆記之八 TensorFlow與中文手寫漢字識別.docx_第2頁
深度學(xué)習(xí)進階筆記之八 TensorFlow與中文手寫漢字識別.docx_第3頁
深度學(xué)習(xí)進階筆記之八 TensorFlow與中文手寫漢字識別.docx_第4頁
深度學(xué)習(xí)進階筆記之八 TensorFlow與中文手寫漢字識別.docx_第5頁
已閱讀5頁,還剩16頁未讀, 繼續(xù)免費閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)

文檔簡介

UCLoud中國云三強: 深度學(xué)習(xí)進階筆記之八 | TensorFlow與中文手寫漢字識別引言TensorFlow是Google基于DistBelief進行研發(fā)的第二代人工智能學(xué)習(xí)系統(tǒng),被廣泛用于語音識別或圖像識別等多項機器深度學(xué)習(xí)領(lǐng)域。其命名來源于本身的運行原理。Tensor(張量)意味著N維數(shù)組,F(xiàn)low(流)意味著基于數(shù)據(jù)流圖的計算,TensorFlow代表著張量從圖象的一端流動到另一端計算過程,是將復(fù)雜的數(shù)據(jù)結(jié)構(gòu)傳輸至人工智能神經(jīng)網(wǎng)中進行分析和處理的過程。TensorFlow完全開源,任何人都可以使用??稍谛〉揭徊恐悄苁謾C、大到數(shù)千臺數(shù)據(jù)中心服務(wù)器的各種設(shè)備上運行。機器學(xué)習(xí)進階筆記系列將深入解析TensorFlow系統(tǒng)的技術(shù)實踐,從零開始,由淺入深,與大家一起走上機器學(xué)習(xí)的進階之路。Goal本文目標(biāo)是利用TensorFlow做一個簡單的圖像分類器,在比較大的數(shù)據(jù)集上,盡可能高效地做圖像相關(guān)處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務(wù)學(xué)習(xí)如果在TensorFlow下做高效地圖像讀取,基本的圖像處理,整個項目很簡單,但其中有一些trick,在實際項目當(dāng)中有很大的好處, 比如堅決不要一次讀入所有的 的數(shù)據(jù)到內(nèi)存(盡管在Mnist這類級別的例子上經(jīng)常出現(xiàn))剛開始看到是這篇blog里面的TensorFlow練習(xí)22: 手寫漢字識別, 但是這篇文章只用了140訓(xùn)練與測試,試了下代碼 很快,但是當(dāng)擴展到所有的時,發(fā)現(xiàn)32g的內(nèi)存都不夠用,這才注意到原文中都是用numpy,會先把所有的數(shù)據(jù)放入到內(nèi)存,但這個不必須的,無論在MXNet還是TensorFlow中都是不必 須的,MXNet使用的是DataIter,會在程序運行的過程中異步讀取數(shù)據(jù),TensorFlow也是這樣的,TensorFlow封裝了高級的api,用來做數(shù)據(jù)的讀取,比如TFRecord,還有就是從filenames中讀取, 來異步讀取文件,然后做shuffle batch,再feed到模型的Graph中來做模型參數(shù)的更新。具體在tf如何做數(shù)據(jù)的讀取可以看看reading data in tensorflow這里我會拿到所有的數(shù)據(jù)集來做訓(xùn)練與測試,算作是對斗大的熊貓上面那篇文章的一個擴展。Batch Generate數(shù)據(jù)集來自于中科院自動化研究所,感謝分享精神!具體下載:wget /databases/download/feature_data/HWDB1.1trn_gnt.zipwget /databases/download/feature_data/HWDB1.1tst_gnt.zip解壓后發(fā)現(xiàn)是一些gnt文件,然后用了斗大的熊貓里面的代碼,將所有文件都轉(zhuǎn)化為對應(yīng)label目錄下的所有png的圖片。(注意在HWDB1.1trn_gnt.zip解壓后是alz文件,需要再次解壓 我在mac沒有找到合適的工具,windows上有alz的解壓工具)。import osimport numpy as npimport structfrom PIL import Imagedata_dir = ./datatrain_data_dir = os.path.join(data_dir, HWDB1.1trn_gnt)test_data_dir = os.path.join(data_dir, HWDB1.1tst_gnt)def read_from_gnt_dir(gnt_dir=train_data_dir): def one_file(f): header_size = 10 while True: header = np.fromfile(f, dtype=uint8, count=header_size) if not header.size: break sample_size = header0 + (header18) + (header216) + (header324) tagcode = header5 + (header48) width = header6 + (header78) height = header8 + (header9H, tagcode).decode(gb2312) char_set.add(tagcode_unicode)char_list = list(char_set)char_dict = dict(zip(sorted(char_list), range(len(char_list)print len(char_dict)import picklef = open(char_dict, wb)pickle.dump(char_dict, f)f.close()train_counter = 0test_counter = 0for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack(H, tagcode).decode(gb2312) im = Image.fromarray(image) dir_name = ./data/train/ + %0.5d%char_dicttagcode_unicode if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert(RGB).save(dir_name+/ + str(train_counter) + .png) train_counter += 1for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): tagcode_unicode = struct.pack(H, tagcode).decode(gb2312) im = Image.fromarray(image) dir_name = ./data/test/ + %0.5d%char_dicttagcode_unicode if not os.path.exists(dir_name): os.mkdir(dir_name) im.convert(RGB).save(dir_name+/ + str(test_counter) + .png) test_counter += 1處理好的數(shù)據(jù),放到了云盤,大家可以直接在我的云盤來下載處理好的數(shù)據(jù)集HWDB1. 這里說明下,char_dict是漢字和對應(yīng)的數(shù)字label的記錄。得到數(shù)據(jù)集后,就要考慮如何讀取了,一次用numpy讀入內(nèi)存在很多小數(shù)據(jù)集上是可以行的,但是在稍微大點的數(shù)據(jù)集上內(nèi)存就成了瓶頸,但是不要害怕,TensorFlow有自己的方法:def batch_data(file_labels,sess, batch_size=128): image_list = file_label0 for file_label in file_labels label_list = int(file_label1) for file_label in file_labels print tag2 0.format(len(image_list) images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string) labels_tensor = tf.convert_to_tensor(label_list, dtype=64) input_queue = tf.train.slice_input_producer(images_tensor, labels_tensor) labels = input_queue1 images_content = tf.read_file(input_queue0) # images = tf.image.decode_png(images_content, channels=1) images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32) # images = images / 256 images = pre_process(images) # print images.get_shape() # one hot labels = tf.one_hot(labels, 3755) image_batch, label_batch = tf.train.shuffle_batch(images, labels, batch_size=batch_size, capacity=50000,min_after_dequeue=10000) # print image_batch, image_batch.get_shape() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) return image_batch, label_batch, coord, threads簡單介紹下,首先你需要得到所有的圖像的path和對應(yīng)的label的列表,利用tf.convert_to_tensor轉(zhuǎn)換為對應(yīng)的tensor, 利用tf.train.slice_input_producer將image_list ,label_list做一個slice處理,然后做圖像的讀取、預(yù)處理,以及l(fā)abel的one_hot表示,然后就是傳到tf.train.shuffle_batch產(chǎn)生一個個shuffle batch,這些就可以feed到你的 模型。 slice_input_producer和shuffle_batch這類操作內(nèi)部都是基于queue,是一種異步的處理方式,會在設(shè)備中開辟一段空間用作cache,不同的進程會分別一直往cache中塞數(shù)據(jù) 和取數(shù)據(jù),保證內(nèi)存或顯存的占用以及每一個mini-batch不需要等待,直接可以從cache中獲取。Data Augmentation由于圖像場景不復(fù)雜,只是做了一些基本的處理,包括圖像翻轉(zhuǎn),改變下亮度等等,這些在TensorFlow里面有現(xiàn)成的api,所以盡量使用TensorFlow來做相關(guān)的處理:def pre_process(images): if FLAGS.random_flip_up_down: images = tf.image.random_flip_up_down(images) if FLAGS.random_flip_left_right: images = tf.image.random_flip_left_right(images) if FLAGS.random_brightness: images = tf.image.random_brightness(images, max_delta=0.3) if FLAGS.random_contrast: images = tf.image.random_contrast(images, 0.8, 1.2) new_size = tf.constant(FLAGS.image_size,FLAGS.image_size, dtype=32) images = tf.image.resize_images(images, new_size) return imagesBuild Graph這里很簡單的構(gòu)造了一個兩個卷積+一個全連接層的網(wǎng)絡(luò),沒有做什么更深的設(shè)計,感覺意義不大,設(shè)計了一個dict,用來返回后面要用的所有op,還有就是為了方便再訓(xùn)練中查看loss和accuracy, 沒有什么特別的,很容易理解, labels 為None時 方便做inference。def network(images, labels=None): endpoints = conv_1 = slim.conv2d(images, 32, 3,3,1, padding=SAME) max_pool_1 = slim.max_pool2d(conv_1, 2,2,2,2, padding=SAME) conv_2 = slim.conv2d(max_pool_1, 64, 3,3,padding=SAME) max_pool_2 = slim.max_pool2d(conv_2, 2,2,2,2, padding=SAME) flatten = slim.flatten(max_pool_2) out = slim.fully_connected(flatten,3755, activation_fn=None) global_step = tf.Variable(initial_value=0) if labels is not None: loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels) train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step) accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1), tf.float32) tf.summary.scalar(loss, loss) tf.summary.scalar(accuracy, accuracy) merged_summary_op = tf.summary.merge_all() output_score = tf.nn.softmax(out) predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3) endpointsglobal_step = global_step if labels is not None: endpointslabels = labels endpointstrain_op = train_op endpointsloss = loss endpointsaccuracy = accuracy endpointsmerged_summary_op = merged_summary_op endpointsoutput_score = output_score endpointspredict_val_top3 = predict_val_top3 endpointspredict_index_top3 = predict_index_top3 return endpointsTraintrain函數(shù)包括從已有checkpoint中restore,得到step,快速恢復(fù)訓(xùn)練過程,訓(xùn)練主要是每一次得到mini-batch,更新參數(shù),每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。def train(): sess = tf.Session() file_labels = get_imagesfile(FLAGS.train_data_dir) images, labels, coord, threads = batch_data(file_labels, sess) endpoints = network(images, labels) saver = tf.train.Saver() sess.run(tf.global_variables_initializer() train_writer = tf.train.SummaryWriter(./log + /train,sess.graph) test_writer = tf.train.SummaryWriter(./log + /val) start_step = 0 if FLAGS.restore: ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt: saver.restore(sess, ckpt) print restore from the checkpoint 0.format(ckpt) start_step += int(ckpt.split(-)-1) (:Training Start:) try: while not coord.should_stop(): # (step 0 start.format(i) start_time = time.time() _, loss_val, train_summary, step = sess.run(endpointstrain_op, endpointsloss, endpointsmerged_summary_op, endpointsglobal_step) train_writer.add_summary(train_summary, step) end_time = time.time() (the step 0 takes 1 loss 2.format(step, end_time-start_time, loss_val) if step FLAGS.max_steps: break # (the step 0 takes 1 loss 2.format(i, end_time-start_time, loss_val) if step % FLAGS.eval_steps = 1: accuracy_val,test_summary, step = sess.run(endpointsaccuracy, endpointsmerged_summary_op, endpointsglobal_step) test_writer.add_summary(test_summary, step) (=Eval a batch in Train data=) ( the step 0 accuracy 1.format(step, accuracy_val) (=Eval a batch in Train data=) if step % FLAGS.save_steps = 1: (Save the ckpt of 0.format(step) saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpointsglobal_step) except tf.errors.OutOfRangeError: # print =train finished= (=Train Finished=) saver.save(sess, os.path.join(FLAGS.checkpoint_dir, my-model), global_step=endpointsglobal_step) finally: coord.request_stop() coord.join(threads) sess.close()GraphLoss and AccuracyValidation訓(xùn)練完成之后,想對完成的模型在測試數(shù)據(jù)集上做一個評估,這里我也曾經(jīng)嘗試?yán)胋atch_data,將slice_input_producer中epoch設(shè)置為1,來做相關(guān)的工作,但是發(fā)現(xiàn)這里無法和train 共用,會出現(xiàn)epoch無初始化值的問題(train中傳epoch為None),所以這里自己寫了shuffle batch的邏輯,將測試集的images和labels通過feed_dict傳進到網(wǎng)絡(luò),得到模型的輸出, 然后做相關(guān)指標(biāo)的計算:def validation(): # it should be fixed by using placeholder with epoch num in train stage sess = tf.Session() file_labels = get_imagesfile(FLAGS.test_data_dir) test_size = len(file_labels) print test_size val_batch_size = FLAGS.val_batch_size test_steps = test_size / val_batch_size print test_steps # images, labels, coord, threads= batch_data(file_labels, sess) images = tf.placeholder(dtype=tf.float32, shape=None, 64, 64, 1) labels = tf.placeholder(dtype=32, shape=None,3755) # read batch images from file_labels # images_batch = np.zeros(128,64,64,1) # labels_batch = np.zeros(128,3755) # labels_batch020 = 1 # endpoints = network(images, labels) saver = tf.train.Saver() ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt: saver.restore(sess, ckpt) # (restore from the checkpoint 0.format(ckpt) # (Start validation) final_predict_val = final_predict_index = groundtruth = for i in range(test_steps): start = i* val_batch_size end = (i+1)*val_batch_size images_batch = labels_batch = labels_max_batch = (=start validation on 0/1 batch=.format(i, test_steps) for j in range(start,end): image_path = file_labelsj0 temp_image = Image.open(image_path).convert(L) temp_image = temp_image.resize(FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) temp_label = np.zeros(3755) label = int(file_labelsj1) # print label temp_labellabel = 1 # print =,np.asarray(temp_image).shape labels_batch.append(temp_label) # print =,np.asarray(temp_image).shape images_batch.append(np.asarray(temp_image)/255.0) labels_max_batch.append(label) # print images_batch images_batch = np.array(images_batch).reshape(-1, 64, 64, 1) labels_batch = np.array(labels_batch) batch_predict_val, batch_predict_index = sess.run(endpointspredict_val_top3, endpointspredict_index_top3, feed_dict=images:images_batch, labels:labels_batch) (=validation on 0/1 batch end=.format(i, test_steps) final_predict_val += batch_predict_val.tolist() final_predict_index += batch_predict_index.tolist() groundtruth += labels_max_batch sess.close() return final_predict_val, final_predict_index, groundtruth在訓(xùn)練20w個step之后,大概能達到在測試集上能夠達到:相信如果在網(wǎng)絡(luò)設(shè)計上多花點時間能夠在一定程度上提升accuracy和top 3 accuracy.有興趣的小伙伴們可以玩玩這個數(shù)據(jù)集。Inferencedef inference(image): temp_image = Image.open(image).convert(L)

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論