This is the sequel to the previous post on How To Create Updatable Models Using Core ML 3.
With Core ML 3, training a Core ML model on a device is a lot easier than taming a dragon!
Prerequisites:
Plan of Action
Run a cat vs. dog classifier Core ML model on a device and relabel predicted images with the opposite label.
Train the batch of relabelled images on the device itself with our updatable model.
Save the newly updated model in the application’s documents directory on your device and use this new model for future predictions or retraining.
Final Destination
An image is worth a thousand words. A GIF is composed of thousands of images. Here’s the final outcome you’ll get by the end of this article.
As you can see in the screengrab, we allow retraining predicted images with the inverse label.
Relabelling and retraining the model doesn’t always guarantee a different prediction.
Now that we know our final result, let’s begin the journey of training your ML model on a device. Before we dive into code, let’s get the hang of the Core ML classes and APIs we’ll be using.
A Brief Look Into the Core ML API
MLModel is the class that encapsulates the model.
We will be discussing the following important classes and protocols in the next sections:
MLFeatureValue
MLImageConstraints
MLFeatureProvider
MLBatchProvider
MLUpdateTask
MLFeatureValue
MLFeatureValue acts as a wrapper for the data. The Core ML model accepts its inputs and returns its outputs in the form of MLFeatureValue.
MLFeatureValue lets us directly use a CGImage. Along with that, we can pass the image constraints for the model. It creates the CVPixelBuffer from the CGImage for you, thereby avoiding the need to write helper methods.
The following piece of code creates an MLFeatureValue instance from an image:
let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: nil)
Now let’s look into MLImageConstraints.
MLImageConstraints
MLImageConstraints is responsible for feeding the model an input image of the correct size. It contains the input information; in our case, that is the image size and the image format.
We can easily retrieve the image constraint object from the model using the following piece of code:
let imageConstraint = model?.modelDescription.inputDescriptionsByName["image"]!.imageConstraint!
We just need to pass the input name (“image”, in our case) to the model description.
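If you want to sanity-check what the model expects, the constraint exposes the image size and pixel format. A quick illustrative check:

// Inspect the input constraint the model expects.
print("Expected size: \(imageConstraint.pixelsWide) x \(imageConstraint.pixelsHigh)")
print("Pixel format: \(imageConstraint.pixelFormatType)")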
MLFeatureProvider
An MLFeatureValue is not passed into the model directly. It needs to be wrapped inside an MLFeatureProvider.
If you inspect the mlmodel Swift file, the generated input and output classes implement the MLFeatureProvider protocol. To access an MLFeatureValue from an MLFeatureProvider, there is a featureValue(for:) accessor method.
MLDictionaryFeatureProvider is a convenience wrapper that holds the data in a dictionary format. It requires the input name ("image", in our case) as the key and the MLFeatureValue as the value.
If there is more than one input, just add them to the same dictionary, as in the sketch below.
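Here’s a minimal illustrative sketch; the "classLabel" entry and its "Cat" value are placeholders showing how a second input would be added:

// Wrap one or more feature values in an MLDictionaryFeatureProvider.
let inputs: [String: MLFeatureValue] = [
    "image": featureValue,
    "classLabel": MLFeatureValue(string: "Cat")
]
let provider = try MLDictionaryFeatureProvider(dictionary: inputs)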
MLBatchProvider
This holds a collection of MLFeatureProviders for batch processing. We can hence predict multiple feature providers or train on a batch of training inputs encapsulated in the MLBatchProvider. In this article, we’ll be doing the latter.
An MLArrayBatchProvider is a concrete implementation backed by an array of feature providers.
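As a quick illustrative sketch, reusing the provider from the previous example, batching boils down to:

// Batch the individual feature providers for training or batch prediction.
let batch = MLArrayBatchProvider(array: [provider])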
MLUpdateTask
An MLUpdateTask is responsible for updating the model with the new training inputs.
Required parameters
Model URL — The location of the compiled model (mlmodelc extension).
Training data — An MLArrayBatchProvider.
Model configuration — Here we pass an MLModelConfiguration. We can use the existing model’s configuration or customize it. For example, we can force the model to run on the CPU, the GPU, and/or the Neural Engine.
Completion handler — It returns the context from which we can access the updated model. Then we can write the model back to the documents directory.
Optional parameters
progressHandlers — Here you pass MLUpdateProgressHandlers with the array of events you want to listen to, such as epoch start/end and training start/end.
progressHandler — This gets called whenever one of the events defined above is triggered (see the sketch below).
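Here’s a minimal illustrative sketch of wiring those handlers together; the chosen events and print statements are just placeholders:

let handlers = MLUpdateProgressHandlers(
    forEvents: [.trainingBegin, .epochEnd],
    progressHandler: { context in
        // Called for each subscribed event during training.
        print("Event: \(context.event), metrics: \(context.metrics)")
    },
    completionHandler: { context in
        // Called once the update task finishes, successfully or with an error.
        print("Training finished, error: \(String(describing: context.task.error))")
    })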
To start the training, just call the resume() function on the updateTask instance.
Here’s a look at pseudocode for training the data on a device:
let updateTask = try MLUpdateTask(forModelAt: updatableModelURL, trainingData: trainingData,
                                  configuration: model.configuration, completionHandler: { context in
    // handle the updated model here
})
updateTask.resume()
Now that we’ve got an idea of the different components and their roles, let’s build our iOS application that trains the model on the device.
Code
Our storyboard
Load a model from a URL
First, let’s write a helper that loads an MLModel from a URL. To start with, we load the compiled model (mlmodelc) that ships in the app bundle:
private func loadModel(url: URL) -> MLModel? {
    do {
        let config = MLModelConfiguration()
        config.computeUnits = .all
        return try MLModel(contentsOf: url, configuration: config)
    } catch {
        print("Error loading model: \(error)")
        return nil
    }
}

let modelURL = Bundle.main.url(forResource: "CatDogUpdatable", withExtension: "mlmodelc")!
let updatableModel = loadModel(url: modelURL)
Predict an image using MLModel
Now that we’ve got our MLModel from the URL, we’ll run the prediction code, assuming we’ve got the image from the UIImagePickerController.
// VNImageCropAndScaleOption requires importing Vision at the top of the file.
func predict(image: UIImage) -> Animal? {
    guard let model = updatableModel else { return nil }
    let imageConstraint = model.modelDescription.inputDescriptionsByName["image"]!.imageConstraint!
    do {
        let imageOptions: [MLFeatureValue.ImageOption: Any] = [
            .cropAndScale: VNImageCropAndScaleOption.scaleFill.rawValue
        ]
        let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: imageOptions)
        let featureProviderDict = try MLDictionaryFeatureProvider(dictionary: ["image": featureValue])
        let prediction = try model.prediction(from: featureProviderDict)
        let value = prediction.featureValue(for: "classLabel")?.stringValue
        if value == "Dog" {
            return .dog
        } else {
            return .cat
        }
    } catch {
        print("error is \(error.localizedDescription)")
    }
    return nil
}
We just pass in the UIImage as a CGImage to the MLFeatureValue along with the MLImageConstraints of the model input, wrap it in an MLDictionaryFeatureProvider, and run the prediction on the MLModel.
The prediction output exposes its results through featureNames and featureValue(for:); classLabel, in our case, contains the label cat or dog.
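For context, here’s a hedged sketch of how predict(image:) might be called from the image picker delegate; this wiring is illustrative and not taken verbatim from the project:

// Assumes the view controller conforms to UIImagePickerControllerDelegate & UINavigationControllerDelegate.
func imagePickerController(_ picker: UIImagePickerController,
                           didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey: Any]) {
    picker.dismiss(animated: true)
    guard let image = info[.originalImage] as? UIImage else { return }
    if let animal = predict(image: image) {
        print("Predicted: \(animal)")
    }
}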
We have a lookup dictionary of UIImage and label, called imageLabelDictionary. If we want to add an image to the training input, we set the image and the inverse of the predicted output (cat/dog) in the dictionary, as sketched below.
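A minimal sketch of that bookkeeping, assuming a [UIImage: String] dictionary; the helper name relabel(image:predicted:) is illustrative:

var imageLabelDictionary: [UIImage: String] = [:]

func relabel(image: UIImage, predicted: Animal) {
    // Store the opposite of the predicted label as the new training target.
    imageLabelDictionary[image] = (predicted == .dog) ? "Cat" : "Dog"
}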
Next, we create a batch provider out of the imageLabelDictionary.
Create a batch provider
Our batch provider creates an MLArrayBatchProvider out of TrainingInput instances, each of which requires the image as a CVPixelBuffer and the classLabel as either cat or dog.
private func batchProvider() -> MLArrayBatchProvider {
    var batchInputs: [MLFeatureProvider] = []
    let imageOptions: [MLFeatureValue.ImageOption: Any] = [
        .cropAndScale: VNImageCropAndScaleOption.scaleFill.rawValue
    ]
    for (image, label) in imageLabelDictionary {
        do {
            let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: imageOptions)
            if let pixelBuffer = featureValue.imageBufferValue {
                let trainingInput = CatDogUpdatableTrainingInput(image: pixelBuffer, classLabel: label)
                batchInputs.append(trainingInput)
            }
        } catch {
            print("error description is \(error.localizedDescription)")
        }
    }
    return MLArrayBatchProvider(array: batchInputs)
}
Thanks to MLFeatureValue, we can easily retrieve the pixel buffer from the feature value’s imageBufferValue property.
Retrieve URL of the MLModel
We need to pass the model URL to the MLUpdateTask. For that, we retrieve the URL from the application’s documents directory using FileManager.
The code is straightforward:
let fileManager = FileManager.default
let documentDirectory = try fileManager.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
let modelURL = documentDirectory.appendingPathComponent("CatDog.mlmodelc")
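On later launches, you’ll typically want to prefer the updated model in the documents directory if it exists and fall back to the bundled one otherwise. Here’s a hedged sketch of that check; the helper name loadUpdatableModel() is ours, not necessarily what the project uses:

func loadUpdatableModel() -> MLModel? {
    let fileManager = FileManager.default
    if let documentDirectory = try? fileManager.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true) {
        let updatedURL = documentDirectory.appendingPathComponent("CatDog.mlmodelc")
        if fileManager.fileExists(atPath: updatedURL.path) {
            // A previously updated model exists; prefer it over the bundled one.
            return loadModel(url: updatedURL)
        }
    }
    // Fall back to the compiled model shipped in the app bundle.
    guard let bundledURL = Bundle.main.url(forResource: "CatDogUpdatable", withExtension: "mlmodelc") else { return nil }
    return loadModel(url: bundledURL)
}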
Now we are ready to train our model again with the new images.
let modelConfig = MLModelConfiguration()
modelConfig.computeUnits = .cpuAndGPU
let updateTask = try MLUpdateTask(forModelAt: modelURL,
                                  trainingData: batchProvider(),
                                  configuration: modelConfig,
                                  progressHandlers: MLUpdateProgressHandlers(forEvents: [.trainingBegin, .epochEnd],
                                                                             progressHandler: { contextProgress in
    print(contextProgress.event)
}) { finalContext in
    if finalContext.task.error == nil {
        let fileManager = FileManager.default
        do {
            let documentDirectory = try fileManager.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
            let fileURL = documentDirectory.appendingPathComponent("CatDog.mlmodelc")
            // Overwrite the old model with the newly trained one and reload it.
            try finalContext.model.write(to: fileURL)
            self.updatableModel = self.loadModel(url: fileURL)
        } catch {
            print("error is \(error.localizedDescription)")
        }
    }
})
updateTask.resume()
finalContext.model.write(to: fileURL) overwrites the model present at that URL in the documents directory. This Core ML model was set to run one epoch only.
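If you want more than one pass over the data, the epoch count can be overridden at update time through the model configuration’s parameters, assuming the model was exported with epochs as an adjustable update parameter; a minimal sketch:

let config = MLModelConfiguration()
config.computeUnits = .cpuAndGPU
// Only works if the updatable model exposes epochs as an adjustable update parameter.
config.parameters = [MLParameterKey.epochs: 5]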
Full Source Code
That concludes Core ML on-device training. The full source code below merges all of the above concepts into a working iOS application. Along with that, the models and Python scripts are available in the GitHub repository.