-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathml_model_training.py
More file actions
65 lines (51 loc) · 2.13 KB
/
Copy pathml_model_training.py
File metadata and controls
65 lines (51 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
# This module expects TensorFlow 2.x. Check it with an explicit raise rather
# than `assert`, because assert statements are removed under `python -O`
# and the check would then silently disappear.
if not tf.__version__.startswith('2'):
    raise ImportError(
        "TensorFlow 2.x is required, found {}".format(tf.__version__))
# Lower the TensorFlow Python logger's verbosity so only errors are printed.
tf.get_logger().setLevel('ERROR')
from pprint import pprint #Pretty printing for output
def _strip_trailing_separators(path):
    r"""Return *path* with any trailing ``/`` or ``\`` characters removed.

    Only the end of the path is trimmed, so absolute paths (``/data/...``)
    and UNC paths (``\\server\share``) keep their leading separators.
    A path consisting solely of separators (e.g. ``/``) is returned as-is
    instead of collapsing to an empty string.
    """
    stripped = path.rstrip('/\\')
    return stripped or path


def _load_pascal_voc(data_dir, classes):
    """Load a Pascal VOC split whose images and annotations share *data_dir*."""
    # The same directory is passed as both the image and annotation dir.
    return object_detector.DataLoader.from_pascal_voc(data_dir, data_dir, classes)


def trainEfficientDetLite0Model(classes, trainPath, valPath, testPath, modelFileName,
                                *, epochs=50, batch_size=4, train_whole_model=True,
                                export_dir='.'):
    """Train, evaluate and export an EfficientDet-Lite0 object detector.

    Each dataset directory must contain both the images and their Pascal VOC
    annotation files, because the same path is used for both.

    Args:
        classes: Label map forwarded to ``DataLoader.from_pascal_voc``
            (for example a list of class names).
        trainPath: Directory with the training images and annotations.
        valPath: Directory with the validation images and annotations.
        testPath: Directory with the test images and annotations.
        modelFileName: File name of the exported ``.tflite`` model.
        epochs: Number of training epochs passed to ``object_detector.create``.
        batch_size: Training batch size passed to ``object_detector.create``.
        train_whole_model: Forwarded to ``object_detector.create``.
        export_dir: Directory the ``.tflite`` model and ``labels.txt`` are
            written to. It is also where the model is read back from for
            the tflite evaluation.

    Returns:
        0 once training, evaluation and export have completed.
    """
    # Trim only trailing separators. The previous strip() also removed
    # leading ones, which turned absolute paths into relative paths.
    train_data = _load_pascal_voc(_strip_trailing_separators(trainPath), classes)
    validation_data = _load_pascal_voc(_strip_trailing_separators(valPath), classes)
    test_data = _load_pascal_voc(_strip_trailing_separators(testPath), classes)

    print("\nUsing an EfficientDet-Lite0 model for training with 320x320 image resolution.")
    spec = object_detector.EfficientDetLite0Spec()

    print("\nTraining starts......")
    model = object_detector.create(train_data=train_data,
                                   model_spec=spec,
                                   validation_data=validation_data,
                                   epochs=epochs,
                                   batch_size=batch_size,
                                   train_whole_model=train_whole_model)

    print("\nEvaluating created model")
    print("Evaluation result:")
    result = model.evaluate(test_data)
    pprint(result, width=10)

    labels_filename = 'labels.txt'
    print("\nExport model to tflite-format")
    model.export(export_dir=export_dir, tflite_filename=modelFileName,
                 label_filename=labels_filename,
                 export_format=[ExportFormat.TFLITE, ExportFormat.LABEL])

    print("\n\nEvaluating tflite-model")
    print("Evaluation result:")
    # export() writes into export_dir, so read the model back from there
    # instead of assuming it landed in the current working directory.
    result = model.evaluate_tflite(os.path.join(export_dir, modelFileName), test_data)
    pprint(result, width=10)
    return 0