TensorFlow Lite Tutorial for Flutter: Image Classification
Learn how to use TensorFlow Lite in Flutter. Train your machine learning model with Teachable Machine and integrate the result into your Flutter mobile app. By Ken Lee.
Loading Classification Labels
Open lib/classifier/classifier.dart and import tflite_flutter_helper:
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
Then add the following code after predict:
static Future<ClassifierLabels> _loadLabels(String labelsFileName) async {
  // #1
  final rawLabels = await FileUtil.loadLabels(labelsFileName);
  // #2
  final labels = rawLabels
      .map((label) => label.substring(label.indexOf(' ')).trim())
      .toList();
  debugPrint('Labels: $labels');
  return labels;
}
Here’s what the above code does:
- Loads the labels using the file utility from tflite_flutter_helper.
- Removes the index number prefix from the labels you previously downloaded. For example, it changes 0 Rose to Rose, as the sketch after this list demonstrates.
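To see the transformation in isolation, here's a minimal standalone sketch you can run on its own; the raw label strings are hypothetical examples in Teachable Machine's "index name" format:
// Standalone sketch of the prefix-stripping step above.
// The raw labels are hypothetical examples ("<index> <name>" format).
void main() {
  final rawLabels = ['0 Rose', '1 Tulip', '2 Sunflower'];
  final labels = rawLabels
      .map((label) => label.substring(label.indexOf(' ')).trim())
      .toList();
  print(labels); // [Rose, Tulip, Sunflower]
}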
Next, replace // TODO: _loadLabels in loadWith by calling _loadLabels like so:
final labels = await _loadLabels(labelsFileName);
This code loads the label file.
Save the changes. There is nothing more to do with the labels now, so it’s time to run a test.
Build and run. Look at the console output:
Congrats, you successfully parsed the model’s labels!
Importing TensorFlow Lite Model
Go to lib/classifier/classifier_model.dart and replace the contents with the following code:
import 'package:tflite_flutter/tflite_flutter.dart';

class ClassifierModel {
  Interpreter interpreter;

  List<int> inputShape;
  List<int> outputShape;

  TfLiteType inputType;
  TfLiteType outputType;

  ClassifierModel({
    required this.interpreter,
    required this.inputShape,
    required this.outputShape,
    required this.inputType,
    required this.outputType,
  });
}
ClassifierModel stores all model-related data for your classifier. You'll use the interpreter to predict the results. inputShape and outputShape are the shapes of the input and output data respectively, while inputType and outputType are the data types of the input and output tensors.
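For context, here are illustrative values for a typical Teachable Machine image model; your model's actual values may differ:
// Illustrative metadata for a typical Teachable Machine image model:
//   inputShape:  [1, 224, 224, 3]  -> 1 image, 224x224 pixels, 3 RGB channels
//   outputShape: [1, 4]            -> 1 row of scores, one per label
//   inputType:   TfLiteType.float32
//   outputType:  TfLiteType.float32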
Now, import the model from the file. Go to lib/classifier/classifier.dart and add the following code after _loadLabels:
static Future<ClassifierModel> _loadModel(String modelFileName) async {
  // #1
  final interpreter = await Interpreter.fromAsset(modelFileName);
  // #2
  final inputShape = interpreter.getInputTensor(0).shape;
  final outputShape = interpreter.getOutputTensor(0).shape;
  debugPrint('Input shape: $inputShape');
  debugPrint('Output shape: $outputShape');
  // #3
  final inputType = interpreter.getInputTensor(0).type;
  final outputType = interpreter.getOutputTensor(0).type;
  debugPrint('Input type: $inputType');
  debugPrint('Output type: $outputType');
  return ClassifierModel(
    interpreter: interpreter,
    inputShape: inputShape,
    outputShape: outputShape,
    inputType: inputType,
    outputType: outputType,
  );
}
Don't forget to add import 'package:tflite_flutter/tflite_flutter.dart'; at the top.
Here’s what happens in the above code:
- Creates an interpreter with the provided model file; the interpreter is the tool that runs predictions (see the optional sketch after this list).
- Reads the input and output shapes, which you'll use to pre-process and post-process your data.
- Reads the input and output types so that you know what kind of data you're working with.
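As a side note, Interpreter.fromAsset also accepts an InterpreterOptions object if you ever need more control over interpreter creation. A minimal optional sketch, not required by this tutorial:
import 'package:tflite_flutter/tflite_flutter.dart';

// Optional: create the interpreter with explicit options.
// `threads` sets how many CPU threads inference may use.
Future<Interpreter> loadInterpreterWithOptions(String modelFileName) {
  final options = InterpreterOptions()..threads = 4;
  return Interpreter.fromAsset(modelFileName, options: options);
}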
Next, replace // TODO: _loadModel in loadWith with the following:
final model = await _loadModel(modelFileName);
The code above loads the model file.
Build and run. Look at the console output:
You successfully parsed the model! Its input and output tensors are multi-dimensional arrays of float32 values.
Finally, for initialization, replace // TODO: build and return Classifier in loadWith with the following:
return Classifier._(labels: labels, model: model);
That builds your Classifier instance, which PlantRecogniser uses to recognize images the user provides.
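Putting the three pieces together, loadWith should now look roughly like this sketch; the try/catch error handling shown here is an assumption and may differ slightly from the starter project:
static Future<Classifier?> loadWith({
  required String labelsFileName,
  required String modelFileName,
}) async {
  try {
    final labels = await _loadLabels(labelsFileName);
    final model = await _loadModel(modelFileName);
    return Classifier._(labels: labels, model: model);
  } catch (e) {
    // Assumed error handling: log the failure and return null.
    debugPrint('Can\'t initialize Classifier: $e');
    return null;
  }
}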
Implementing TensorFlow Prediction
Before doing any prediction, you need to prepare the input.
You'll write a method to convert the Image object (from the image package) to TensorImage, the tensor structure TensorFlow uses for images. You also need to modify the image to fit the required shape of the model.
Pre-Processing Image Data
With the help of tflite_flutter_helper, image processing is simple because the library provides several functions you can pull in to handle image reshaping.
Add the _preProcessInput method to lib/classifier/classifier.dart:
TensorImage _preProcessInput(Image image) {
  // #1
  final inputTensor = TensorImage(_model.inputType);
  inputTensor.loadImage(image);
  // #2
  final minLength = min(inputTensor.height, inputTensor.width);
  final cropOp = ResizeWithCropOrPadOp(minLength, minLength);
  // #3
  final shapeLength = _model.inputShape[1];
  final resizeOp = ResizeOp(shapeLength, shapeLength, ResizeMethod.BILINEAR);
  // #4
  final normalizeOp = NormalizeOp(127.5, 127.5);
  // #5
  final imageProcessor = ImageProcessorBuilder()
      .add(cropOp)
      .add(resizeOp)
      .add(normalizeOp)
      .build();
  imageProcessor.process(inputTensor);
  // #6
  return inputTensor;
}
_preProcessInput preprocesses the Image object so that it becomes the required TensorImage. These are the steps involved:
- Creates the TensorImage and loads the image data into it.
- Crops the image to a square shape. You have to import dart:math at the top to use the min function.
- Resizes the image to fit the shape requirements of the model.
- Normalizes the data values. The argument 127.5 is chosen based on your trained model's parameters: it converts each pixel's 0-255 value into the -1...1 range (see the worked example after this list).
- Creates the image processor with the defined operations and preprocesses the image.
- Returns the preprocessed image.
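To make the normalization concrete, here's a small standalone sketch of what NormalizeOp(127.5, 127.5) does to individual pixel values:
// NormalizeOp(mean, std) maps each value v to (v - mean) / std.
// With mean = std = 127.5, the 0-255 pixel range becomes -1...1.
double normalize(int v) => (v - 127.5) / 127.5;

void main() {
  print(normalize(0)); // -1.0
  print(normalize(128)); // ~0.004, near the middle of the range
  print(normalize(255)); // 1.0
}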
Then, invoke the method inside predict(...) at // TODO: _preProcessInput:
final inputImage = _preProcessInput(image);
debugPrint(
  'Pre-processed image: ${inputImage.width}x${inputImage.height}, '
  'size: ${inputImage.buffer.lengthInBytes} bytes',
);
You’ve implemented your pre-processing logic.
Build and run.
Pick an image from the gallery and look at the console:
You successfully converted the image to the model’s required shape!
Next, you’ll run the prediction.
Running the Prediction
Add the following code at // TODO: run TF Lite to run the prediction:
// #1
final outputBuffer = TensorBuffer.createFixedSize(
  _model.outputShape,
  _model.outputType,
);
// #2
_model.interpreter.run(inputImage.buffer, outputBuffer.buffer);
debugPrint('OutputBuffer: ${outputBuffer.getDoubleList()}');
Here’s what happens in the code above:
- TensorBuffer stores the final scores of your prediction in raw format (see the sizing sketch after this list).
- The interpreter reads the tensor image and stores the output in the buffer.
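As a sanity check, the fixed-size buffer holds exactly as many values as the output shape implies. A sketch with a hypothetical [1, 4] float32 output:
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

void main() {
  // A [1, 4] float32 output yields a buffer of 1 * 4 = 4 scores.
  final buffer = TensorBuffer.createFixedSize([1, 4], TfLiteType.float32);
  print(buffer.getDoubleList().length); // 4
}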
Build and run.
Select an image from your gallery and observe the console:
Great job! You successfully got a raw prediction from the model. Just a few more steps to make the result friendly to human users. That brings you to the next task: post-processing the result.
Post-Processing the Output Result
The TensorFlow output result is a similarity score for each label, and it looks like this:
[0.0, 0.2, 0.9, 0.0]
It's hard to tell which value refers to which label unless you happen to be the one who created the model.
Add the following method to lib/classifier/classifier.dart:
List<ClassifierCategory> _postProcessOutput(TensorBuffer outputBuffer) {
  // #1
  final probabilityProcessor = TensorProcessorBuilder().build();
  probabilityProcessor.process(outputBuffer);
  // #2
  final labelledResult = TensorLabel.fromList(_labels, outputBuffer);
  // #3
  final categoryList = <ClassifierCategory>[];
  labelledResult.getMapWithFloatValue().forEach((key, value) {
    final category = ClassifierCategory(key, value);
    categoryList.add(category);
    debugPrint('label: ${category.label}, score: ${category.score}');
  });
  // #4
  categoryList.sort((a, b) => (b.score > a.score ? 1 : -1));
  return categoryList;
}
Here’s the logic for your new post-processing method:
- Creates an instance of TensorProcessorBuilder to parse and process the output.
- Maps output values to your labels.
- Builds category instances from the label–score records (a minimal sketch of ClassifierCategory follows this list).
- Sorts the list to place the most likely result at the top.
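For reference, ClassifierCategory comes from the starter project; here's a minimal sketch of the shape it needs to have for the code above to work:
// Minimal sketch of the starter project's label/score record.
class ClassifierCategory {
  final String label;
  final double score;

  ClassifierCategory(this.label, this.score);

  @override
  String toString() => 'Category{label: $label, score: $score}';
}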
Great, now you just need to invoke _postProcessOutput() for the prediction. Update predict(...) so that it looks like the following:
ClassifierCategory predict(Image image) {
  // Load the image and convert it to TensorImage for TensorFlow input
  final inputImage = _preProcessInput(image);

  // Define the output buffer
  final outputBuffer = TensorBuffer.createFixedSize(
    _model.outputShape,
    _model.outputType,
  );

  // Run inference
  _model.interpreter.run(inputImage.buffer, outputBuffer.buffer);

  // Post-process the outputBuffer
  final resultCategories = _postProcessOutput(outputBuffer);
  final topResult = resultCategories.first;
  debugPrint('Top category: $topResult');

  return topResult;
}
You wired your new post-processing method into the TensorFlow output, so predict(...) now returns the most likely category first.
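If you want to exercise predict(...) in isolation, a hypothetical call site might look like this; the helper name and file-loading details are assumptions, not part of the starter project:
import 'dart:io';

import 'package:flutter/foundation.dart';
import 'package:image/image.dart' as img;

// Hypothetical helper: decode an image file and log the top category.
Future<void> classifyFile(Classifier classifier, String path) async {
  final bytes = await File(path).readAsBytes();
  final image = img.decodeImage(bytes);
  if (image == null) return;

  final topCategory = classifier.predict(image);
  debugPrint('Prediction: ${topCategory.label} (${topCategory.score})');
}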
Build and run.
Upload an image and watch it correctly predict the plant:
Congratulations! That was a good ride.
Next, you'll learn how the Classifier works to produce this result.