Cross Validation in Keras

"If my goal is to fine-tune the network for the entire dataset"

It is not clear what you mean by “fine-tune”, or even what exactly your purpose is for performing cross-validation (CV); in general, CV serves one of the following purposes:

  • Model selection (choose the values of hyperparameters)
  • Model assessment (estimate how well the model is expected to perform on unseen data)

Since you don’t define any search grid for hyperparameter selection in your code, it would seem that you are using CV in order to get the expected performance of your model (error, accuracy, etc.).
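For model assessment, the essential point is that every fold gets its own freshly initialised model, trained only on that fold's training part and scored only on its held-out part; the per-fold scores are then averaged. Below is a minimal, model-agnostic sketch of that loop, assuming NumPy is available. The names `create_model` and `train_evaluate` follow the question's code, but their bodies here are hypothetical: `LinearModel` is a least-squares regressor standing in for a compiled Keras model, and `kfold_indices` is a hand-rolled stand-in for scikit-learn's unshuffled `KFold` split.

```python
import numpy as np


class LinearModel:
    """Hypothetical stand-in for a compiled Keras model: least-squares regression."""

    def __init__(self):
        self.w = None  # untrained: no weights yet

    def fit(self, X, y):
        Xb = np.hstack([X, np.ones((len(X), 1))])  # append a bias column
        self.w, *_ = np.linalg.lstsq(Xb, y, rcond=None)

    def evaluate(self, X, y):
        Xb = np.hstack([X, np.ones((len(X), 1))])
        return float(np.mean((Xb @ self.w - y) ** 2))  # mean squared error


def create_model():
    # Must return a brand-new, untrained model on every call
    # (with Keras, this is where the model would be built and compiled).
    return LinearModel()


def train_evaluate(model, x_train, y_train, x_test, y_test):
    model.fit(x_train, y_train)            # train on this fold's training part only
    return model.evaluate(x_test, y_test)  # score on the held-out part


def kfold_indices(n_samples, n_splits):
    # Contiguous, unshuffled folds, like scikit-learn's KFold(shuffle=False).
    folds = np.array_split(np.arange(n_samples), n_splits)
    for i, test in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train, test


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
Y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=100)

scores = []
for train, test in kfold_indices(len(X), 5):
    model = create_model()  # fresh model inside the loop, one per fold
    scores.append(train_evaluate(model, X[train], Y[train], X[test], Y[test]))

print(f"CV MSE: {np.mean(scores):.4f} +/- {np.std(scores):.4f}")
```

The mean of `scores` is the CV estimate of the model's performance; the standard deviation gives a rough idea of how much it varies from fold to fold.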

Anyway, for whatever reason you are using CV, the first snippet is the correct one; your second snippet

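    model = None
    model = create_model()  # one model, created once, outside the loop
    for train, test in kFold.split(X, Y):
        # the same (already trained) model keeps training on each new fold
        train_evaluate(model, X[train], Y[train], X[test], Y[test])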

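will train your model sequentially over the different partitions (i.e., train on partition #1, then continue training on partition #2, etc.), which essentially amounts to training on your whole data set; it is certainly not cross-validation. Worse, from the second fold onward, each test partition has already been used for training in an earlier fold, so the reported scores are optimistically biased.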
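To see concretely why the scores from the second snippet cannot be trusted, the sketch below simulates one shared model across 5 folds of 10 samples and counts, for each fold, how many of its test samples the model has already been trained on. It assumes NumPy; `kfold_indices` is a hypothetical helper that produces the same contiguous, unshuffled folds as scikit-learn's `KFold` without shuffling.

```python
import numpy as np


def kfold_indices(n_samples, n_splits):
    # Contiguous, unshuffled folds, like scikit-learn's KFold(shuffle=False).
    folds = np.array_split(np.arange(n_samples), n_splits)
    for i, test in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train, test


seen = set()  # indices the single shared model has been trained on so far
leaked = []   # per fold: number of test samples already seen in training
for train, test in kfold_indices(10, 5):
    leaked.append(len(seen & set(test.tolist())))
    seen |= set(train.tolist())  # the shared model now "knows" these samples

print(leaked)  # [0, 2, 2, 2, 2]
```

Only the first fold is evaluated on genuinely unseen data; after it, every test sample has already appeared in training.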

That said, a final step after CV, which is often only implied (and frequently missed by beginners), is this: once you are satisfied with your chosen hyperparameters and/or model performance as estimated by your CV procedure, you go back and train your model again, this time on the entire available data.
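A minimal sketch of the whole workflow, assuming NumPy: CV picks a hyperparameter (model selection), then one fresh model is trained on all of the data with that choice. To keep it runnable without Keras, closed-form ridge regression is swapped in as a hypothetical stand-in model, and its regularisation strength `alpha` plays the role of whatever hyperparameter you would tune; `kfold_indices`, `fit_ridge`, `mse` and `cv_score` are illustrative helpers, not library functions.

```python
import numpy as np


def kfold_indices(n_samples, n_splits):
    # Contiguous, unshuffled folds, like scikit-learn's KFold(shuffle=False).
    folds = np.array_split(np.arange(n_samples), n_splits)
    for i, test in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train, test


def fit_ridge(X, y, alpha):
    # Closed-form ridge regression (no bias term, for brevity).
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ y)


def mse(w, X, y):
    return float(np.mean((X @ w - y) ** 2))


def cv_score(X, y, alpha, n_splits=5):
    # Mean held-out MSE; fresh weights are fitted on every fold.
    return np.mean([mse(fit_ridge(X[tr], y[tr], alpha), X[te], y[te])
                    for tr, te in kfold_indices(len(X), n_splits)])


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 20))
y = X[:, 0] - X[:, 1] + 0.5 * rng.normal(size=60)

# 1) Model selection: pick the hyperparameter with the best CV score.
alphas = [0.01, 0.1, 1.0, 10.0, 100.0]
cv_scores = {a: cv_score(X, y, a) for a in alphas}
best_alpha = min(cv_scores, key=cv_scores.get)

# 2) Final step: retrain once on the *entire* data set with that choice.
final_w = fit_ridge(X, y, best_alpha)
print(best_alpha, round(cv_scores[best_alpha], 3))
```

The per-fold models exist only to produce the CV scores and are thrown away; `final_w` is the model you would actually keep.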
