RM新时代网站-首页

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

如何針對涂鴉識別問題構建基于RNN的識別器

Tensorflowers ? 來源:未知 ? 作者:胡薇 ? 2018-11-27 09:13 ? 次閱讀

Quick, Draw!是一款游戲;在這個游戲中,玩家要接受一項挑戰(zhàn):繪制幾個圖形,看看計算機能否識別玩家繪制的是什么。

Quick, Draw!的識別操作 由一個分類器執(zhí)行,它接收用戶輸入(用 (x, y) 中的點筆畫序列表示),然后識別用戶嘗試涂鴉的圖形所屬的類別。

在本教程中,我們將展示如何針對此問題構建基于 RNN 的識別器。該模型將結合使用卷積層、LSTM 層和 softmax 輸出層對涂鴉進行分類:

上圖顯示了我們將在本教程中構建的模型的結構。輸入為一個涂鴉,用 (x, y, n) 中的點筆畫序列表示,其中 n 表示點是否為新筆畫的第一個點。

然后,模型將應用一系列一維卷積,接下來,會應用 LSTM 層,并將所有 LSTM 步的輸出之和饋送到 softmax 層,以便根據我們已知的涂鴉類別來決定涂鴉的分類。

本教程使用的數(shù)據來自真實的Quick, Draw!游戲,這些數(shù)據是公開提供的。此數(shù)據集包含 5000 萬幅涂鴉,涵蓋 345 個類別。

運行教程代碼

要嘗試本教程的代碼,請執(zhí)行以下操作:

安裝 TensorFlow(如果尚未安裝的話)

下載教程代碼

下載數(shù)據(TFRecord格式),然后解壓縮。如需詳細了解如何獲取原始 Quick, Draw!數(shù)據以及如何將數(shù)據轉換為TFRecord文件,請參閱下文

使用以下命令執(zhí)行教程代碼,以訓練本教程中所述的基于 RNN 的模型。請務必調整路徑,使其指向第 3 步中下載的解壓縮數(shù)據

python train_model.py \ --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \ --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \ --classes_file=rnn_tutorial_data/training.tfrecord.classes

教程詳情

下載數(shù)據

我們將本教程中要使用的數(shù)據放在了包含TFExamples的TFRecord文件中。您可以從以下位置下載這些數(shù)據:http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz(大約 1GB)。

或者,您也可以從 Google Cloud 下載ndjson格式的原始數(shù)據,并將這些數(shù)據轉換為包含TFExamples的TFRecord文件,如下一部分中所述。

可選:下載完整的 QuickDraw 數(shù)據

完整的Quick, Draw!數(shù)據集可在 Google Cloud Storage 上找到,此數(shù)據集是按類別劃分的ndjson文件。您可以在 Cloud Console 中瀏覽文件列表。

要下載數(shù)據,我們建議使用gsutil下載整個數(shù)據集。請注意,原始 .ndjson 文件需要下載約 22GB 的數(shù)據。

然后,使用以下命令檢查 gsutil 安裝是否成功以及您是否可以訪問數(shù)據存儲分區(qū):

gsutil ls -r "gs://quickdraw_dataset/full/simplified/*"

系統(tǒng)會輸出一長串文件,如下所示:

gs://quickdraw_dataset/full/simplified/The Eiffel Tower.ndjsongs://quickdraw_dataset/full/simplified/The Great Wall of China.ndjsongs://quickdraw_dataset/full/simplified/The Mona Lisa.ndjsongs://quickdraw_dataset/full/simplified/aircraft carrier.ndjson...

之后,創(chuàng)建一個文件夾并在其中下載數(shù)據集。

mkdir rnn_tutorial_datacd rnn_tutorial_datagsutil -m cp "gs://quickdraw_dataset/full/simplified/*" .

下載過程需要花費一段時間,且下載的數(shù)據量略超 23GB。

可選:轉換數(shù)據

要將ndjson文件轉換為TFRecord文件(包含tf.train.Example樣本),請運行以下命令。

python create_dataset.py --ndjson_path rnn_tutorial_data \ --output_path rnn_tutorial_data

此命令會將數(shù)據存儲在TFRecord文件的 10 個分片中,每個類別有 10000 項用于訓練數(shù)據,有 1000 項用于評估數(shù)據。

下文詳細說明了該轉換過程。

原始 QuickDraw 數(shù)據的格式為ndjson文件,其中每行包含一個如下所示的 JSON 對象:

{"word":"cat","countrycode":"VE","timestamp":"2017-03-02 23:25:10.07453 UTC","recognized":true,"key_id":"5201136883597312","drawing":[ [ [130,113,99,109,76,64,55,48,48,51,59,86,133,154,170,203,214,217,215,208,186,176,162,157,132], [72,40,27,79,82,88,100,120,134,152,165,184,189,186,179,152,131,114,100,89,76,0,31,65,70] ],[ [76,28,7], [136,128,128] ],[ [76,23,0], [160,164,175] ],[ [87,52,37], [175,191,204] ],[ [174,220,246,251], [134,132,136,139] ],[ [175,255], [147,168] ],[ [171,208,215], [164,198,210] ],[ [130,110,108,111,130,139,139,119], [129,134,137,144,148,144,136,130] ],[ [107,106], [96,113] ]]}

在構建我們的分類器時,我們只關注 “word” 和 “drawing” 字段。在解析 ndjson 文件時,我們使用一個函數(shù)逐行處理它們,該函數(shù)可將drawing字段中的筆畫轉換為大小為[number of points, 3](包含連續(xù)點的差異)的張量。此函數(shù)還會以字符串形式返回類別名稱。

def parse_line(ndjson_line): """Parse an ndjson line and return ink (as np array) and classname.""" sample = json.loads(ndjson_line) class_name = sample["word"] inkarray = sample["drawing"] stroke_lengths = [len(stroke[0]) for stroke in inkarray] total_points = sum(stroke_lengths) np_ink = np.zeros((total_points, 3), dtype=np.float32) current_t = 0 for stroke in inkarray: for i in [0, 1]: np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i] current_t += len(stroke[0]) np_ink[current_t - 1, 2] = 1 # stroke_end # Preprocessing. # 1. Size normalization. lower = np.min(np_ink[:, 0:2], axis=0) upper = np.max(np_ink[:, 0:2], axis=0) scale = upper - lower scale[scale == 0] = 1 np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale # 2. Compute deltas. np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2] return np_ink, class_name

由于我們希望數(shù)據在寫入時進行隨機處理,因此我們以隨機順序從每個類別文件中讀取數(shù)據并寫入隨機分片。

對于訓練數(shù)據,我們讀取每個類別的前 10000 項;對于評估數(shù)據,我們讀取每個類別接下來的 1000 項。

然后,將這些數(shù)據變形為[num_training_samples, max_length, 3]形狀的張量。接下來,我們用屏幕坐標確定原始涂鴉的邊界框并標準化涂鴉的尺寸,使涂鴉具有單位高度。

最后,我們計算連續(xù)點之間的差異,并將它們存儲為VarLenFeature(位于tensorflow.Example中的ink鍵下)。另外,我們將class_index存儲為單一條目FixedLengthFeature,將ink的shape存儲為長度為 2 的FixedLengthFeature。

定義模型

要定義模型,我們需要創(chuàng)建一個新的Estimator。如需詳細了解 Estimator,建議您閱讀此教程。

要構建模型,我們需要執(zhí)行以下操作:

將輸入調整回原始形狀,其中小批次通過填充達到其內容的最大長度。除了 ink 數(shù)據之外,我們還擁有每個樣本的長度和目標類別。這可通過函數(shù)_get_input_tensors實現(xiàn)

將輸入傳遞給_add_conv_layers中的一系列卷積層

將卷積的輸出傳遞到_add_rnn_layers中的一系列雙向 LSTM 層。最后,將每個時間步的輸出相加,針對輸入生成一個固定長度的緊湊嵌入

在_add_fc_layers中使用 softmax 層對此嵌入進行分類

代碼如下所示:

inks, lengths, targets = _get_input_tensors(features, targets)convolved = _add_conv_layers(inks)final_state = _add_rnn_layers(convolved, lengths)logits =_add_fc_layers(final_state)

_get_input_tensors

要獲得輸入特征,我們先從特征字典獲得形狀,然后創(chuàng)建大小為[batch_size](包含輸入序列的長度)的一維張量。ink 作為稀疏張量存儲在特征字典中,我們將其轉換為密集張量,然后變形為[batch_size, ?, 3]。最后,如果傳入目標,我們需要確保它們存儲為大小為[batch_size]的一維張量。

代碼如下所示:

shapes = features["shape"]lengths = tf.squeeze( tf.slice(shapes, begin=[0, 0], size=[params["batch_size"], 1]))inks = tf.reshape( tf.sparse_tensor_to_dense(features["ink"]), [params["batch_size"], -1, 3])if targets is not None: targets = tf.squeeze(targets)

_add_conv_layers

您可以通過params字典中的參數(shù)num_conv和conv_len配置所需的卷積層數(shù)量和過濾器長度。

輸入是一個每個點維數(shù)都是 3 的序列。我們將使用一維卷積,將 3 個輸入特征視為通道。這意味著輸入為[batch_size, length, 3]張量,而輸出為[batch_size, length, number_of_filters]張量。

convolved = inksfor i in range(len(params.num_conv)): convolved_input = convolved if params.batch_norm: convolved_input = tf.layers.batch_normalization( convolved_input, training=(mode == tf.estimator.ModeKeys.TRAIN)) # Add dropout layer if enabled and not first convolution layer. if i > 0 and params.dropout: convolved_input = tf.layers.dropout( convolved_input, rate=params.dropout, training=(mode == tf.estimator.ModeKeys.TRAIN)) convolved = tf.layers.conv1d( convolved_input, filters=params.num_conv[i], kernel_size=params.conv_len[i], activation=None, strides=1, padding="same", name="conv1d_%d" % i)return convolved, lengths

_add_rnn_layers

我們將卷積的輸出傳遞給雙向 LSTM 層,對此我們使用 contrib 的輔助函數(shù)。

outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn( cells_fw=[cell(params.num_nodes) for _ in range(params.num_layers)], cells_bw=[cell(params.num_nodes) for _ in range(params.num_layers)], inputs=convolved, sequence_length=lengths, dtype=tf.float32, scope="rnn_classification")

請參閱代碼以了解詳情以及如何使用CUDA加速實現(xiàn)。

要創(chuàng)建一個固定長度的緊湊嵌入,我們需要將 LSTM 的輸出相加。我們首先將其中的序列不含數(shù)據的批次區(qū)域設為 0。

mask = tf.tile( tf.expand_dims(tf.sequence_mask(lengths, tf.shape(outputs)[1]), 2), [1, 1, tf.shape(outputs)[2]])zero_outside = tf.where(mask, outputs, tf.zeros_like(outputs))outputs = tf.reduce_sum(zero_outside, axis=1)

_add_fc_layers

將輸入的嵌入傳遞至全連接層,之后將此層用作 softmax 層。

tf.layers.dense(final_state, params.num_classes)

損失、預測和優(yōu)化器

最后,我們需要添加一個損失函數(shù)、一個訓練操作和預測來創(chuàng)建ModelFn:

cross_entropy = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=targets, logits=logits))# Add the optimizer.train_op = tf.contrib.layers.optimize_loss( loss=cross_entropy, global_step=tf.train.get_global_step(), learning_rate=params.learning_rate, optimizer="Adam", # some gradient clipping stabilizes training in the beginning. clip_gradients=params.gradient_clipping_norm, summaries=["learning_rate", "loss", "gradients", "gradient_norm"])predictions = tf.argmax(logits, axis=1)return model_fn_lib.ModelFnOps( mode=mode, predictions={"logits": logits, "predictions": predictions}, loss=cross_entropy, train_op=train_op, eval_metric_ops={"accuracy": tf.metrics.accuracy(targets, predictions)})

訓練和評估模型

要訓練和評估模型,我們可以借助EstimatorAPI 的功能,并使用ExperimentAPI 輕松運行訓練和評估操作:

estimator = tf.estimator.Estimator( model_fn=model_fn, model_dir=output_dir, config=config, params=model_params) # Train the model. tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.TRAIN, tfrecord_pattern=FLAGS.training_data, batch_size=FLAGS.batch_size), train_steps=FLAGS.steps, eval_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.EVAL, tfrecord_pattern=FLAGS.eval_data, batch_size=FLAGS.batch_size), min_eval_frequency=1000)

請注意,本教程只是用一個相對較小的數(shù)據集進行簡單演示,目的是讓您熟悉遞歸神經網絡和 Estimator 的 API。如果在大型數(shù)據集上嘗試,這些模型可能會更強大。

當模型完成 100 萬個訓練步后,分數(shù)最高的候選項的準確率預計會達到 70% 左右。請注意,這種程度的準確率足以構建 Quick, Draw! 游戲,由于該游戲的動態(tài)特性,用戶可以在系統(tǒng)準備好識別之前調整涂鴉。此外,如果目標類別顯示的分數(shù)高于固定閾值,該游戲不會僅使用分數(shù)最高的候選項,而且會將某個涂鴉視為正確的涂鴉。

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 神經網絡
    +關注

    關注

    42

    文章

    4771

    瀏覽量

    100713
  • 識別器
    +關注

    關注

    0

    文章

    20

    瀏覽量

    7579

原文標題:Quick, Draw! 涂鴉分類遞歸神經網絡

文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    USB3.0的識別問

    做了一個USB3.0集線,現(xiàn)在遇到這個問題,USB3.0無法識別,插拔幾次后可以識別,接上其他設備也能正常工作,求高手幫忙啊,怎么處理USB3.0識別問
    發(fā)表于 10-29 11:36

    輪掃按鍵識別問

    大俠出來相求,每一個按鍵都可以唯一被識別嗎?機理是什么?
    發(fā)表于 07-27 16:58

    2812識別問

    用2812+cpld采集圖像然后再用2812識別,這個圖像識別很簡單,只是識別圖像中有幾條條紋??梢宰鰡??求解。
    發(fā)表于 04-16 11:08

    語音識別問

    各位大神,我想完成用SPCE061A來實現(xiàn)非特定人的語音識別技術,并能夠使得發(fā)出的命令能在LCD上顯示,不知有沒有能夠指導一下的,大概的框架和模塊,拜托各位了。。。
    發(fā)表于 01-06 22:47

    請教 LD3320 語音識別問

    在X寶買了一塊LD3320 模塊,用的是并行通訊,讀寫寄存都正常,啟動識別后有中斷, 識別結果寄存(0xBA)一直是0 . 是什么問題呀? 有沒人做成功的.分享下經驗!!! 謝謝
    發(fā)表于 03-28 13:43

    OCR識別問

    我用圖像助手訓練的時候能識別數(shù)字,但是訓練完后還是不能識別?為什么~~求大神告知一下下
    發(fā)表于 12-07 11:21

    DHCP識別問題如何解決

    我有一些DHCP服務不使用和諧網絡棧來識別單元的問題??磥恚绻疫B接一臺筆記本電腦到服務,它是公認的罰款。是否有少量信息用于識別?我注意到,筆記本電腦與和聲棧相比,發(fā)送了很多東西
    發(fā)表于 05-11 13:21

    如何解決網絡無法識別問

    網絡問題分類網絡無法識別問題還是比較好排查,但是如果涉及到網絡丟包牽扯的環(huán)節(jié)太多了比如交換芯片是否異常,對方的工作模式是否正常、網絡隔離變壓是否正常、CPU占用率、設備中斷影響先排除網絡環(huán)境和對方設備、在確認設備問題比如phy的時鐘是否重疊、phy的流控是否開啟等等..
    發(fā)表于 12-23 06:08

    離線語音識別和控制的工作原理及應用

    :   1.信號采集   離線語音識別系統(tǒng)的第一步是信號采集。聲音信號通過麥克風(傳感)以電信號的形式被捕捉到,這是后續(xù)處理的基礎。   2.預處理   預處理階段包括去除噪聲、回聲消除、降噪等處理
    發(fā)表于 11-07 18:01

    USB硬盤的系統(tǒng)識別問

      1、 如果系統(tǒng)裝的是win98,如不能被正確識別(即使安裝了USB2.0通用驅動也識別不了),這種情況下要檢查一下你的移動硬盤是否供電不足,如果供電不足就會出現(xiàn)“咳咳”的聲
    發(fā)表于 08-31 17:19 ?1040次閱讀

    貼片電容壞了怎么識別

    貼片電容如何識別?識別方法有哪些?,最近網上出現(xiàn)很多的貼片電容識別問題,很多人因為對貼片電容的容值識別不了解,導致失誤的機率提高。下面小編分享一下貼片電容的
    發(fā)表于 05-10 14:48 ?1.2w次閱讀

    USB智能識別IC可解決傳統(tǒng)USB口的識別問

    USB智能識別IC(PL515,PL513),適用于車充,充電器,移動電源等 USB口輸出供電方案。 USB智能識別IC,是用來解決傳統(tǒng)USB口的識別電阻,識別電阻做的
    的頭像 發(fā)表于 10-15 14:20 ?6618次閱讀
    USB智能<b class='flag-5'>識別</b>IC可解決傳統(tǒng)USB口的<b class='flag-5'>識別問</b>題

    HID_CDC復合設備在WIN10的識別問

    HID_CDC復合設備在WIN10的識別問題(電源技術發(fā)展綜述)-本文以STM32F405為例,詳細說明上HID_CDC復合設備在WIN10的識別問題。
    發(fā)表于 08-04 18:23 ?20次下載
    HID_CDC復合設備在WIN10的<b class='flag-5'>識別問</b>題

    STM32F0的USART波特率自動識別問

    電子發(fā)燒友網站提供《STM32F0的USART波特率自動識別問題.pdf》資料免費下載
    發(fā)表于 08-01 11:00 ?2次下載
    STM32F0的USART波特率自動<b class='flag-5'>識別問</b>題

    Purple Pi OH固件的芯片信息識別問題說明

    開源鴻蒙硬件方案領跑者觸覺智能本文適用于在PurplePiOH固件的芯片信息識別問題說明。觸覺智能的PurplePiOH鴻蒙開源主板,是華為Laval官方社區(qū)主薦的一款鴻蒙開發(fā)主板。該主板主要針對
    的頭像 發(fā)表于 06-26 08:32 ?258次閱讀
    Purple Pi OH固件的芯片信息<b class='flag-5'>識別問</b>題說明
    RM新时代网站-首页