// @takane takane / createmltest.swift
// Created at Fri Jun 09 13:55:55 JST 2023 - forked from ml2022/1874d747a461534b2314e9f6540f4584
// SwiftからCreateMLを実行するプログラム (a program that runs Create ML from Swift)
// createmltest.swift
// Raw
// createmltest.swift
// Usage: 
//      swift createmltest.swift <train folder path> <test folder path>
//
// reference:
// https://www.netguru.com/blog/createml-start-your-adventure-in-machine-learning-with-swift
//
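// note:
// - The train and test folders are specified as command-line arguments.
// - Changed so that the .mlmodel file is written to the Desktop in the home directory.
// - Model parameters set explicitly below:
//   maxIterations
//   augmentation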
 
import CreateML
import Foundation
 
// Training data: a folder containing one subdirectory per class label, with that
// label's images inside it (the `labeledDirectories` layout).
// The folder is the first command-line argument, as the usage line documents.
// When no argument is given, the original hard-coded path is used as a fallback.
let trainingDirectoryURL = URL(fileURLWithPath: CommandLine.arguments.count > 1
    ? CommandLine.arguments[1]
    : "/Users/takane/Desktop/CreateMLTest/CreateMLTest/Data/image/train")
let trainingData = MLImageClassifier.DataSource.labeledDirectories(at: trainingDirectoryURL)

// Train the classifier. A plain `try` lets a training failure terminate the script as a
// top-level error that carries its description, consistent with the model export below.
let classifier = try MLImageClassifier(trainingData: trainingData,
                                       parameters: MLImageClassifier.ModelParameters(
                                           maxIterations: 500,
                                           augmentation: [.noise, .crop, .blur]))

// Training & validation accuracies, in percent (1 - classification error).
let trainingAccuracy = (1.0 - classifier.trainingMetrics.classificationError) * 100
let validationAccuracy = (1.0 - classifier.validationMetrics.classificationError) * 100
print(String(format: "Training accuracy: %.2f%%", trainingAccuracy))
print(String(format: "Validation accuracy: %.2f%%", validationAccuracy))
 
// Testing data: same labeled-subdirectory layout as the training data.
// The folder is the second command-line argument, as the usage line documents.
// When it is missing, the original hard-coded path is used as a fallback.
let testingDirectoryURL = URL(fileURLWithPath: CommandLine.arguments.count > 2
    ? CommandLine.arguments[2]
    : "/Users/takane/Desktop/CreateMLTest/CreateMLTest/Data/image/test")
let testingData = MLImageClassifier.DataSource.labeledDirectories(at: testingDirectoryURL)

// Evaluate the trained classifier on the test images.
let evaluationMetrics = classifier.evaluation(on: testingData)
let evaluationAccuracy = (1.0 - evaluationMetrics.classificationError) * 100

// Confusion matrix, to see which labels were classified wrongly.
let confusionMatrix = evaluationMetrics.confusion

// `evaluation(on:)` does not throw; a failure is reported through the metrics' `error`
// instead, so check it rather than printing meaningless numbers.
if let evaluationError = evaluationMetrics.error {
    print("Evaluation failed: \(evaluationError)")
} else {
    print(String(format: "Evaluation accuracy: %.2f%%", evaluationAccuracy))
    print("Confusion matrix: \(confusionMatrix)")
}
 
// Metadata embedded in the exported Core ML model.
let metadata = MLModelMetadata(author: "S.Takane",
                               shortDescription: "Cats and Dogs",
                               version: "1.0")

// Save the model to ~/Desktop. Apple documents the `write(to:metadata:)` URL as the location
// of the model's .mlmodel file, so name the file explicitly instead of passing the Desktop
// folder itself. Build the path with URL APIs rather than string concatenation.
let modelURL = URL(fileURLWithPath: NSHomeDirectory())
    .appendingPathComponent("Desktop", isDirectory: true)
    .appendingPathComponent("CatsAndDogs.mlmodel")
try classifier.write(to: modelURL, metadata: metadata)
print("Saved model to \(modelURL.path)")