// @takane takane / createmltest.swift
// Created at Fri Jun 09 13:55:55 JST 2023 - forked from ml2022/1874d747a461534b2314e9f6540f4584
// SwiftからCreateMLを実行するプログラム (a program that runs Create ML from Swift)
// createmltest.swift
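// createmltest.swift
// Usage:
//      swift createmltest.swift <train folder path> <test folder path>
//
// Reference:
// https://www.netguru.com/blog/createml-start-your-adventure-in-machine-learning-with-swift
//
// Notes:
// - The train and test folders are passed as command-line arguments.
// - The generated .mlmodel file is written to the Desktop folder in the user's home directory.
// - Data augmentation options go in the `augmentation: [...]` list passed to
//   MLImageClassifier.ModelParameters. Adding `.rotation` to that list causes an error.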

import CreateML
import Foundation

// MARK: - Helpers

/// Writes `message` (plus a newline) to standard error and terminates the
/// script with exit status 1.
func exitWithError(_ message: String) -> Never {
    FileHandle.standardError.write(Data((message + "\n").utf8))
    exit(1)
}

/// Runs `body` and returns its result. If `body` throws, reports `context`
/// together with the error on standard error and exits with status 1, instead
/// of trapping the way `try!` or an uncaught top-level `try` would.
@discardableResult
func runOrExit<T>(_ context: String, _ body: () throws -> T) -> T {
    do {
        return try body()
    } catch {
        exitWithError("\(context): \(error)")
    }
}

/// Returns a directory file URL for `path`, or exits with an error when `path`
/// does not name an existing directory.
func directoryURL(forArgument path: String) -> URL {
    var isDirectory: ObjCBool = false
    guard FileManager.default.fileExists(atPath: path, isDirectory: &isDirectory),
          isDirectory.boolValue else {
        exitWithError("Not a directory: \(path)")
    }
    return URL(fileURLWithPath: path, isDirectory: true)
}

// MARK: - Command-line arguments

// Both folder paths are required. Without this check, reading a missing
// argument would trap with "Index out of range".
guard CommandLine.arguments.count >= 3 else {
    exitWithError("Usage: swift createmltest.swift <train folder path> <test folder path>")
}
// Validate both folders up front so a mistyped test path is reported before
// training starts rather than after it.
let trainFolderURL = directoryURL(forArgument: CommandLine.arguments[1])
let testFolderURL = directoryURL(forArgument: CommandLine.arguments[2])

// MARK: - Training

// Labeled training data: each subdirectory of the train folder holds the
// images for one label.
let trainingData = MLImageClassifier.DataSource.labeledDirectories(at: trainFolderURL)

// Data augmentation options go in `augmentation:`; `.rotation` was found to
// cause an error here, so it is left out.
let parameters = MLImageClassifier.ModelParameters(maxIterations: 1000,
                                                   augmentation: [.crop, .blur, .exposure, .noise, .flip])
let classifier = runOrExit("Training failed") {
    try MLImageClassifier(trainingData: trainingData, parameters: parameters)
}

// Accuracies in percent, derived from the classification error rates.
let trainingAccuracy = (1.0 - classifier.trainingMetrics.classificationError) * 100
let validationAccuracy = (1.0 - classifier.validationMetrics.classificationError) * 100

// MARK: - Evaluation

let testingData = MLImageClassifier.DataSource.labeledDirectories(at: testFolderURL)
let evaluationMetrics = classifier.evaluation(on: testingData)
// evaluation(on:) does not throw; a failure is reported through `error`.
// Warn but keep going so the trained model is still saved.
if let evaluationError = evaluationMetrics.error {
    FileHandle.standardError.write(Data("Warning: evaluation failed: \(evaluationError)\n".utf8))
}
let evaluationAccuracy = (1.0 - evaluationMetrics.classificationError) * 100

// To see which labels were classified wrongly, print the confusion matrix:
// print("Confusion matrix: \(evaluationMetrics.confusion)")
print(evaluationMetrics)
print("Training Accuracy: \(trainingAccuracy)%")
print("Validation Accuracy: \(validationAccuracy)%")
print("Evaluation Accuracy: \(evaluationAccuracy)%")
print("")

// MARK: - Saving

// Metadata embedded in the saved model.
let metadata = MLModelMetadata(author: "S.Takane",
                               shortDescription: "Cats and Dogs",
                               version: "1.0")

// Save into ~/Desktop. NOTE(review): the destination is the Desktop directory
// itself, not a "<name>.mlmodel" file URL; confirm the file name Create ML
// produces for a directory target, or append an explicit file name.
let desktopURL = URL(fileURLWithPath: NSHomeDirectory())
    .appendingPathComponent("Desktop", isDirectory: true)
runOrExit("Saving the model failed") {
    try classifier.write(to: desktopURL, metadata: metadata)
}