-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
28 lines (22 loc) · 796 Bytes
/
train.py
File metadata and controls
28 lines (22 loc) · 796 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from model_zoo.trainer import BaseTrainer
from model_zoo.preprocess import standardize
from model_zoo import flags, datasets
# Register this script's flag defaults, one (name, default) pair per flag,
# in the same order as before.
for _flag_spec in (
    ('epochs', 100),
    ('model_class_name', 'HousePricePredictionModel'),
    ('checkpoint_name', 'model.h5'),
    ('checkpoint_save_weights_only', False),
):
    flags.define(*_flag_spec)
# Drop the loop variable so no extra name is left in the module namespace.
del _flag_spec
class Trainer(BaseTrainer):
    """
    Train Price Prediction Model.

    Only ``data`` is overridden here; everything else, including ``run``,
    is inherited from ``BaseTrainer``.
    """

    def data(self):
        """
        Prepare train and evaluation data.

        Loads the Boston housing split via
        ``datasets.boston_housing.load_data()`` and passes the train and
        eval feature arrays through ``standardize`` together.

        :return: a 2-tuple ``(train, eval_data)``. ``train`` is whatever
            ``self.generator(x_train, y_train)`` returns, and ``eval_data``
            is the plain ``(x_eval, y_eval)`` tuple.
        """
        (x_train, y_train), (x_eval, y_eval) = datasets.boston_housing.load_data()
        # NOTE(review): presumably standardize fits its statistics on x_train
        # and applies them to both arrays. Confirm in model_zoo.preprocess.
        x_train, x_eval = standardize(x_train, x_eval)
        # NOTE(review): generator is inherited from BaseTrainer; assumed to
        # wrap (x, y) into a batch source for training. Confirm there.
        return self.generator(x_train, y_train), (x_eval, y_eval)
if __name__ == '__main__':
    # Script entry point: build the trainer, then start the inherited run().
    trainer = Trainer()
    trainer.run()