在這篇文章中,我們將利用 TensorFlow.js,D3.js 和 Web 的力量使訓(xùn)練模型的過(guò)程可視化,以預(yù)測(cè)棒球數(shù)據(jù)中的壞球(藍(lán)色區(qū)域)和好球(橙色區(qū)域)。 隨著我們的進(jìn)展,我們將模型在整個(gè)訓(xùn)練過(guò)程中理解的打擊區(qū)域可視化。您可以通過(guò)訪問(wèn)此 Observable 筆記本在瀏覽器中運(yùn)行此模型。
注:Observable鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d
如果你不熟悉棒球的擊球區(qū),這里有一篇詳細(xì)的文章。
上面的 GIF 可視化神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)調(diào)用壞球(藍(lán)色區(qū)域)和好球(橙色區(qū)域)在每個(gè)訓(xùn)練步驟之后,熱圖會(huì)根據(jù)模型的預(yù)測(cè)進(jìn)行更新
使用 Observable 直接在瀏覽器中運(yùn)行此模型。
注:文章鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d
體育運(yùn)動(dòng)中的高級(jí)指標(biāo)
當(dāng)今的職業(yè)體育環(huán)境中充斥著大量的數(shù)據(jù)。這些數(shù)據(jù)被團(tuán)隊(duì),業(yè)余愛(ài)好者和粉絲應(yīng)用于各種用例中。感謝像 TensorFlow 這樣的框架 - 這些數(shù)據(jù)集已準(zhǔn)備好應(yīng)用于機(jī)器學(xué)習(xí)。
美國(guó)職業(yè)棒球大聯(lián)盟先進(jìn)媒體(MLBAM)的 PITCHf/x
美國(guó)職業(yè)棒球大聯(lián)盟先進(jìn)媒體(MLBAM)發(fā)布了一個(gè)可供公眾研究的大型數(shù)據(jù)集。該數(shù)據(jù)集包含有關(guān)過(guò)去幾年在美國(guó)職業(yè)棒球大聯(lián)盟比賽中投擲的投球的傳感器信息。 利用這個(gè)數(shù)據(jù)集,我們已編寫(xiě)了一個(gè)包含 5,000 個(gè)樣本的訓(xùn)練集(2,500 個(gè)壞球和 2,500 個(gè)好球)。
以下是訓(xùn)練數(shù)據(jù)中前幾個(gè)字段的示例:
注:示例鏈接
https://gist.github.com/nkreeger/01b5386b522b0cd1f22bc864320f3084#file-baseball-training-data-sample-csv
以下是針對(duì)打擊區(qū)域繪制的訓(xùn)練數(shù)據(jù)的樣子。藍(lán)點(diǎn)標(biāo)記為壞球,橙點(diǎn)標(biāo)記為好球(此為大聯(lián)盟裁判員稱(chēng)謂):
利用 TensorFlow.js 構(gòu)建模型
TensorFlow.js 將機(jī)器學(xué)習(xí)引入 JavaScript 和 Web。 我們將利用這個(gè)很棒的框架來(lái)構(gòu)建一個(gè)深度神經(jīng)網(wǎng)絡(luò)模型。這個(gè)模型將能夠按大聯(lián)盟裁判的精準(zhǔn)度來(lái)稱(chēng)呼好球和壞球。
輸入 Input
該模型在 PITCHf / x 的以下字段中進(jìn)行了訓(xùn)練:
協(xié)調(diào)球越過(guò)本壘的位置('px'和'pz')。
擊球手站在壘的哪一側(cè)。
擊球區(qū)(擊球手的軀干)的高度,以英尺為單位。
擊球區(qū)底部的高度(擊球手的膝蓋)以英尺為單位。
裁判所稱(chēng)的投球(好球或壞球)的實(shí)際標(biāo)簽。
結(jié)構(gòu) Architecture
該模型將通過(guò)使用 TensorFlow.js 圖層 API 定義。Layers API 基于 Keras,對(duì)以前使用過(guò)該框架的人來(lái)說(shuō)應(yīng)該很熟悉:
1const model = tf.sequential();
2
3// Two fully connected layers with dropout between each:
4model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
5model.add(tf.layers.dropout({rate: 0.01}));
6model.add(tf.layers.dense({units: 16, activation: 'relu'}));
7model.add(tf.layers.dropout({rate: 0.01}));
8
9// Only two classes: "strike" and "ball":
10model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
11
12model.compile({
13optimizer: tf.train.adam(0.01),
14loss: 'categoricalCrossentropy',
15metrics: ['accuracy']
16});
加載和準(zhǔn)備數(shù)據(jù)
精選的訓(xùn)練集可通過(guò)GitHub gist 獲得。需要下載此數(shù)據(jù)集才能開(kāi)始將 CSV 數(shù)據(jù)轉(zhuǎn)換為 TensorFlow.js 用于訓(xùn)練的格式。
注:GitHub gist 鏈接
https://gist.github.com/nkreeger/43edc6e6daecc2cb02a2dd3293a08f29
1const data = [];
2csvData.forEach((values) => {
3// 'logit' data uses the 5 fields:
4const x = [];
5x.push(parseFloat(values.px));
6x.push(parseFloat(values.pz));
7x.push(parseFloat(values.sz_top));
8x.push(parseFloat(values.sz_bot));
9x.push(parseFloat(values.left_handed_batter));
10// The label is simply 'is strike' or 'is ball':
11const y = parseInt(values.is_strike, 10);
12data.push({x: x, y: y});
13});
14// Shuffle the contents to ensure the model does not always train on the same
15// sequence of pitch data:
16tf.util.shuffle(data);
解析 CSV 數(shù)據(jù)后,需要將 JS 類(lèi)型轉(zhuǎn)換為 Tensor 批次進(jìn)行培訓(xùn)和評(píng)估。有關(guān)此過(guò)程的詳細(xì)信息,請(qǐng)參閱代碼實(shí)驗(yàn)室。TensorFlow.js 團(tuán)隊(duì)正在開(kāi)發(fā)一種新的 Data API,以便將來(lái)更容易獲取。
注:代碼實(shí)驗(yàn)室
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#batches
訓(xùn)練模型
讓我們把這一切都整合在一起吧。定義了模型,準(zhǔn)備好了訓(xùn)練數(shù)據(jù),現(xiàn)在我們已經(jīng)準(zhǔn)備好開(kāi)始訓(xùn)練了。以下異步方法訓(xùn)練一批訓(xùn)練樣本并更新熱圖:
1// Trains and reports loss+accuracy for one batch of training data:
2async function trainBatch(index) {
3const history = await model.fit(batches[index].x, batches[index].y, {
4epochs: 1,
5shuffle: false,
6validationData: [batches[index].x, batches[index].y],
7batchSize: CONSTANTS.BATCH_SIZE
8});
9
10// Don't block the UI frame by using tf.nextFrame()
11await tf.nextFrame();
12updateHeatmap();
13await tf.nextFrame();
14}
可視化模型的準(zhǔn)確性
使用來(lái)自均勻放置在本壘板上方的 4 x 4 英尺柵格的預(yù)測(cè)矩陣來(lái)構(gòu)建熱圖。在每個(gè)訓(xùn)練步驟之后將該矩陣傳遞到模型中以檢查模型的準(zhǔn)確度。使用 D3 庫(kù)將該預(yù)測(cè)的結(jié)果呈現(xiàn)為熱圖。
構(gòu)建預(yù)測(cè)矩陣
熱圖中使用的預(yù)測(cè)矩陣從本壘板的中間開(kāi)始,向左和向右各延伸 2 英尺。它的范圍也從本壘板的底部到 4 英尺高。擊打區(qū)樣本位于本壘板上方 1.5 至 3.5 英尺之間。下圖有助于讓這些 2d 窗格可視化:
該視覺(jué)顯示了打擊區(qū)域和預(yù)測(cè)矩陣與本壘板和游戲區(qū)域相關(guān)的位置
將預(yù)測(cè)矩陣與模型一起使用
每個(gè)批次在模型中訓(xùn)練之后,預(yù)測(cè)矩陣被傳遞到模型中用以請(qǐng)求矩陣中的好球或壞球預(yù)測(cè):
1function predictZone() {
2const predictions = model.predictOnBatch(predictionMatrix.data);
3const values = predictions.dataSync();
4
5// Sort each value so the higher prediction is the first element in the array:
6const results = [];
7let index = 0;
8for (let i = 0; i < values.length; i++) { ? ?
9let list = [];
10list.push({value: values[index++], strike: 0});
11list.push({value: values[index++], strike: 1});
12list = list.sort((a, b) => b.value - a.value);
13results.push(list);
14}
15return results;
16}
熱圖與 D3
現(xiàn)在可以使用 D3 顯示預(yù)測(cè)結(jié)果。 來(lái)自 50x50 網(wǎng)格中的每一個(gè)元素將在 SVG 中呈現(xiàn)為 10px x 10px 的矩形。每個(gè)矩形的顏色取決于預(yù)測(cè)結(jié)果(好球或者壞球)以及模型對(duì)該結(jié)果的確定程度(范圍從 50%-100%)。 以下代碼段顯示了如何從 D3 svg 矩形分組更新數(shù)據(jù):
1function updateHeatmap() {
2rects.data(generateHeatmapData());
3rects
4.attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
5.attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
6.attr('width', CONSTANTS.HEATMAP_SIZE)
7.attr('height', CONSTANTS.HEATMAP_SIZE)
8.style('fill', (coord) => {
9if (coord.strike) {
10return strikeColorScale(coord.value);
11} else {
12return ballColorScale(coord.value);
13}
14});
15}
有關(guān)使用 D3 繪制熱圖的完整詳細(xì)信息,請(qǐng)參閱此部分。
注:此部分鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#colorDomain
總結(jié)
網(wǎng)絡(luò)上有許多令人驚嘆的第三方庫(kù)和工具,可用于創(chuàng)建視覺(jué)效果。將這些與機(jī)器學(xué)習(xí)的強(qiáng)大功能與 TensorFlow.js 相結(jié)合,開(kāi)發(fā)人員能夠創(chuàng)建一些非常新奇有趣的演示。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4771瀏覽量
100713 -
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8406瀏覽量
132561 -
tensorflow
+關(guān)注
關(guān)注
13文章
329瀏覽量
60527
原文標(biāo)題:棒球比賽中是好球還是壞球?TensorFlow.js 已經(jīng)知道
文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論