Worum geht es?

In diesem Blog-Artikel möchte ich darüber berichten, wie ein trainiertes Keras-Modell in Tensorflow.js eingebunden werden kann, um mit dem Modell clientseitig und ausschließlich im Browser Prognosen durchführen zu können.

Das Modell

Da ich selbst meine ersten Schritte mit künstlicher Intelligenz und Tensorflow mache, habe ich mich für einen einfachen Use Case entschieden. Dieser wird oft als Einstieg verwendet wird. Mithilfe des MNIST-Datensatzes, welcher 60.000 Trainingsdaten für handgeschriebene Zahlen beinhaltet, soll das Modell eigens geschriebene Zahlen erkennen. Auf die Details des verwendeten Modells gehe ich in diesem Beitrag nicht ein, da dieses bereits auf unseren Blog erläutert wurden.

Konvertierung des Modells

Das letztendlich trainierte Modell ist bei Keras in Form einer HDF5-Datei (Dateiendung h5) abgebildet. Diese Datei enthält die Topologie sowie die Gewichte des Modells. Da Tensorflow.js eine solchen Datei nicht unterstützt, muss sie zunächst konvertiert werden.

Für diese Konvertierung bietet Google das Python-Paket „tensorflowjs“ an. Über eine bereitgestellte API kann das Modell direkt aus dem Python-Code, wo es trainiert wird, umgewandelt werden.

import tensorflowjs as tfjs
…
tfjs.converters.save_keras_model(model, ziel_ordner)

Eine weitere Möglichkeit ist, ein bereits bestehendes Modell mit dem Command-Line-Interface des Paketes umzuwandeln. Da ich das Modell bereits trainiert habe, habe ich diese Methode angewendet

tensorflowjs_converter --input_format keras \
                       pfad/zum/model.h5 \
                       pfad/zum/ziel_ordner

Nach dem Konvertieren wurden mehrere Dateien erstellt. Die Datei „model.json“ enthält die Topologie des Modells. Die Gewichte des Modells werden in bin-Dateien gespeichert. Um das Caching großer Modelle für Browser zu verbessern, werden die Gewichte in mehrere Dateien aufgeteilt. Standardmäßig sind die einzelnen Fragmente 4 MB groß. Über den Flag –weight_shard_size_bytes kann dieser Wert angepasst werden.

Neben Keras-Modellen können auch „Tensorflow SavedModels“, „Keras Saved Models“ und „Tensorflow Hub Modules“ konvertiert werden. Weitere Informationen zur Konvertierung können unter folgendem Link nachgelesen werden: https://github.com/tensorflow/tfjs-converter

Erstellung der Webseite

Um das Modell im Browser nutzen zu können, braucht es eine Webseite, auf der Ziffern gezeichnet werden können. Dazu habe ich eine simple Seite mit HTML erstellt und die JavaScript-Library p5.js eingebunden, welche Interaktionen mit Canvas-Elementen von HTML erleichtert. Nun kann man auf dieser Webseite mithilfe des Mauszeigers in einem Canvas gezeichnet werden.

Einbinden des Modells in die Webseite

Um das konvertierte Modell auf der Webseite nutzen zu können, muss zunächst die JavaScript-Library Tensorflow.js eingebunden werden. In meinem Fall habe ich die Library über NPM bezogen. Hierfür wird der Befehl npm install @tensorflow/tfjs oder für yarn yarn add @tensorflow/tfjs verwendet. Die Library kann auch statisch im HTML nachgeladen werden:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>.

Wenn die Library über NPM bezogen wurde, muss mit import * as tf from ‚@tensorflow/tfjs‘ oder const tf = require(‚@tensorflow/tfjs‘) die Library in eine JavaScript-Datei eingebunden werden. Anschließend steht eine Funktion bereit, mit der das Modell geladen werden kann: const model = await tf.loadLayersModel(‚pfad/zum/model.json‘)

Um das Modell nutzen zu können, müssen die gezeichneten Bilder aus dem Canvas in Tensoren transformiert werden, welches das Modell versteht. Dazu stellt die Library eine Funktion bereit, in der ein ImageData-Objekt übergeben wird, das man aus dem Canvas erhält. Anschließend wird der Tensor umgeformt und nach float32 umgewandelt, sodass er den Ansprüchen des verwendeten Modells entspricht. Über die Funktion „predict“ berechnet Tensorflow.js mit dem erstellten Tensor die Prognose für die gezeichnete Zahl. Anschließend werden die Ergebnisse in die Form eines Arrays umgewandelt, welches für jede Ziffer eine Wahrscheinlichkeit enthält.

const tensor = tf.browser.fromPixels(imageData, 1)
  .reshape([1, 28, 28, 1])
  .cast('float32')
const output = model.predict(tensor);
const predictions = Array.from(output.dataSync());

Zur Veranschaulichung kann die erstellte Demo gerne ausprobiert werden.
Der Sourcecode ist ebenfalls auf GitHub verfügbar. Außerdem gibt es eine weitere Demo die bereits von Kollegen der MT AG gebaut wurde.

Jetzt teilen auf:

Jetzt kommentieren