import turicreate as tc
import re

# Load the training (./Train9) and test (./Test9) image sets, in that order.
# with_path=True keeps each image's file path in a 'path' column, which is
# what the labelling step reads to derive the class label.
train_data, test_data = (
    tc.image_analysis.load_images(folder, with_path=True)
    for folder in ('./Train9', './Test9')
)

def makelabel(path):
    """Return the class label for an image, derived from its file path.

    Each ``(pattern, label)`` rule below is tried in order with
    ``re.search``. The first pattern found anywhere in *path* decides the
    label. Because the match is not anchored, a parent directory or a file
    name that begins right after a ``/`` can also trigger it. If no rule
    matches, ``'Watchglass'`` is returned.

    NOTE(review): any unrecognised path is silently labelled 'Watchglass'.
    A misspelled folder name would therefore be mislabelled rather than
    reported. Confirm this is intended.
    """
    # Order matters: earlier rules win when several patterns occur in the path.
    rules = (
        (r'/beaker', 'beaker'),
        (r'/crucible', 'crucible'),
        (r'/flask', 'flask'),
        (r'/funnel', 'funnel'),
        (r'/test tube', 'test tube'),
        (r'/burette', 'burette'),
        (r'/Eggplantflask', 'Eggplantflask'),
        (r'/glassrod', 'glassrod'),
        (r'/Komagomepipette', 'Komagomepipette'),
        (r'/Messcylinder', 'Messcylinder'),
        (r'/Messflask', 'Messflask'),
        (r'/Messpipette', 'Messpipette'),
        (r'/Petridish', 'Petridish'),
        (r'/thermometer', 'thermometer'),
        (r'/vollpipette', 'vollpipette'),
    )
    for pattern, label in rules:
        if re.search(pattern, path):
            return label
    # Fallback for every path that matched none of the rules above.
    return 'Watchglass'

# Derive each image's class label from its file path. makelabel is passed
# directly: wrapping it in `lambda path: makelabel(path)` added nothing.
train_data['label'] = train_data['path'].apply(makelabel)
test_data['label'] = test_data['path'].apply(makelabel)

# Open Turi Create's interactive viewer on the labelled training set so the
# labels can be checked visually.
# NOTE(review): this is interactive; confirm it behaves on headless machines.
train_data.explore()

# Train an image classifier on the 'label' column, using the
# 'squeezenet_v1.1' pretrained model, capped at 100 iterations.
model = tc.image_classifier.create(
    train_data, target='label', model='squeezenet_v1.1', max_iterations=100
)

# Evaluate on the test set and print its accuracy.
metrics = model.evaluate(test_data)
print(metrics['accuracy'])

# Save the model in Turi Create's native format.
model.save('./ImageClassification.model')

# Export the model in Core ML format.
model.export_coreml('./ImageClassification.mlmodel')