const config = require('../config')
const tf = require('@tensorflow/tfjs')
const utils = require('./utils')

/**
 * Runs `model` on a table of prediction rows.
 *
 * `predictionData` is a list of rows whose first row is the header. When the
 * last header cell equals `config.LABEL_COLUMN_NAME`, the label column is
 * stripped with `utils.removeColumn` and the result is handed to the model
 * unchanged. Otherwise the header row is dropped and the remaining rows are
 * converted to a 2-D tensor before prediction.
 *
 * @param {object} model - Object exposing `predict(input)`.
 * @param {Array<Array>} predictionData - Header row followed by data rows.
 * @returns {Promise<*>} Resolves to whatever `model.predict` returns.
 * @throws {TypeError} If `predictionData` is missing or has no header row.
 */
async function predict(model, predictionData) {
    if (predictionData == null || predictionData[0] == null) {
        throw new TypeError('predictionData must contain a header row followed by data rows');
    }

    const header = predictionData[0];
    const lastHeaderCell = header[header.length - 1];

    if (config.DEBUG) {
        console.log(lastHeaderCell);
        console.log(config.LABEL_COLUMN_NAME);
    }

    if (lastHeaderCell === config.LABEL_COLUMN_NAME) {
        // NOTE(review): assumes utils.removeColumn returns something
        // model.predict accepts directly (it is not converted here) - confirm.
        const data = utils.removeColumn(predictionData, config.LABEL_COLUMN_NAME);
        return model.predict(data);
    }

    const dataWithoutHeader = predictionData.slice(1);
    const tensorData = tf.tensor2d(dataWithoutHeader);
    try {
        if (config.DEBUG) {
            tensorData.print();
        }
        return model.predict(tensorData);
    } finally {
        // The input tensor is owned by this function; release it so repeated
        // predictions do not leak memory.
        tensorData.dispose();
    }
}

/**
 * Computes the fraction of predictions that lie within `threshold` of their
 * label.
 *
 * Entries are coerced with `Number()`, so single-element rows such as `[2.5]`
 * are compared by their numeric value. A missing or non-numeric entry yields
 * NaN and is counted as incorrect.
 *
 * @param {ArrayLike<*>} labels - Expected values; its length sets the sample count.
 * @param {ArrayLike<*>} predictedLabels - Predicted values, index-aligned with `labels`.
 * @param {number} threshold - Maximum absolute error still counted as correct.
 * @returns {number} Ratio in [0, 1]; 0 when `labels` is empty.
 */
function calculateAccuracy(labels, predictedLabels, threshold) {
    const total = labels.length;
    if (total === 0) {
        // Avoid 0 / 0 => NaN when there is nothing to score.
        return 0;
    }

    let correctCount = 0;
    for (let i = 0; i < total; i++) {
        const error = Math.abs(Number(labels[i]) - Number(predictedLabels[i]));
        if (error <= threshold) {
            correctCount += 1;
        }
    }
    return correctCount / total;
}

/**
 * Builds and compiles the dense regression network used by `train`:
 * 128 -> 64 -> 1 units, ReLU on the hidden layers, linear output,
 * Adam optimizer and mean-squared-error loss.
 *
 * @param {number} featureCount - Number of input features per sample.
 * @returns {object} The compiled sequential model.
 */
function buildRegressionModel(featureCount) {
    const model = tf.sequential();
    model.add(tf.layers.dense({ units: 128, inputShape: [featureCount], activation: 'relu' }));
    model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
    model.add(tf.layers.dense({ units: 1, activation: 'linear' }));
    const optimizer = tf.train.adam(0.000008);
    model.compile({ optimizer, loss: 'meanSquaredError' });
    return model;
}

/**
 * Runs `model.predict(inputs)` and returns the result as nested arrays.
 * The call is wrapped in `tf.tidy` so the prediction tensor is released.
 *
 * @param {object} model - Model exposing `predict`.
 * @param {object} inputs - Input tensor passed to `model.predict`.
 * @returns {Array} Output of `arraySync()` on the prediction.
 */
function predictToArray(model, inputs) {
    return tf.tidy(() => model.predict(inputs).arraySync());
}

/**
 * Trains a regression model on the table in `TrainingFile`.
 *
 * The table is split into features and labels with `utils.splitTable`, then
 * split twice with `utils.trainValidationSplit`: once with
 * `trainingValidationSplit`, and once more (ratio 0.5) on the validation part
 * of the first split. The model is fitted on the training part of the first
 * split and validated on the validation part of the second split.
 *
 * @param {number} Epochs - Number of epochs passed to `model.fit`.
 * @param {*} TrainingFile - Input forwarded to `utils.splitTable`.
 * @param {function(number, number)} epochCallback - Called after each epoch with
 *     (1-based epoch number, Epochs).
 * @param {function(number)} updateEpochLoss - Called after each epoch with `logs.loss`.
 * @param {number} batchSize - Batch size passed to `model.fit`.
 * @param {number} threshold - Tolerance forwarded to `calculateAccuracy`.
 * @param {number} trainingValidationSplit - Ratio forwarded to the first
 *     `utils.trainValidationSplit` call.
 * @returns {Promise<{model: object, testAccuracy: *, runningTimeInMin: string}>}
 *     The fitted model, the value returned by `utils.calcAvgAccInPrecentages`,
 *     and the fit duration in minutes formatted with two decimals.
 */
async function train(Epochs, TrainingFile, epochCallback, updateEpochLoss, batchSize, threshold, trainingValidationSplit){
    const { data, labels } = utils.splitTable(TrainingFile, config.LABEL_COLUMN_NAME);

    console.debug('data shape:', data.shape);
    console.debug('labels shape:', labels.shape);

    // Log the tensors to the console
    if (config.DEBUG) {
        data.print();
        labels.print();
    }

    // TODO(yaakov): Consider renaming "Final" to better represent what it does
    // TODO(yaakov): Replace "First split" and "Second split" with more meaningful names

    // First split
    const { trainingData, trainingLabels, validationData, validationLabels } = utils.trainValidationSplit(data, labels, trainingValidationSplit);

    // Second split: divides the validation part of the first split in half
    const {
        trainingData: trainingDataFinal,
        trainingLabels: trainingLabelsFinal,
        validationData: validationDataFinal,
        validationLabels: validationLabelsFinal,
    } = utils.trainValidationSplit(validationData, validationLabels, 0.5);

    if (config.DEBUG) {
        validationDataFinal.print();
        trainingDataFinal.print();
    }

    // The model
    const model = buildRegressionModel(data.shape[1]);

    // The label sets never change during training, so materialise them once
    // instead of on every epoch.
    const trainingLabelsFinalArray = trainingLabelsFinal.arraySync();
    const validationLabelsFinalArray = validationLabelsFinal.arraySync();

    // Training loop
    const start = Date.now();

    const callbacks = {
        onEpochEnd: async (epoch, logs) => {
            // NOTE(review): "Training Accuracy" is measured on trainingDataFinal,
            // which comes from the second split of validationData, not on the
            // trainingData passed to model.fit - confirm this is intended.
            const trainAccuracy = calculateAccuracy(
                trainingLabelsFinalArray,
                predictToArray(model, trainingDataFinal),
                threshold) * 100;
            const validationAccuracy = calculateAccuracy(
                validationLabelsFinalArray,
                predictToArray(model, validationDataFinal),
                threshold) * 100;

            console.debug(`Epoch number: ${epoch + 1}`);
            console.debug('Training Accuracy:', `${trainAccuracy.toFixed(2)}%`);
            console.debug('Validation Accuracy:', `${validationAccuracy.toFixed(2)}%`);
            epochCallback(epoch + 1, Epochs);
            updateEpochLoss(logs.loss);
        },
    };

    await model.fit(trainingData, trainingLabels, {
        epochs: Epochs,
        verbose: 0,
        // TODO(yaakov): Why is this "Final" and not just plain?
        validationData: [validationDataFinal, validationLabelsFinal],
        batchSize: batchSize,
        callbacks: callbacks,
    });

    const end = Date.now();
    const runningTimeInMin = ((end - start) / (60 * 1000)).toFixed(2);
    console.debug(`Training time: ${runningTimeInMin} minutes`);

    // Final evaluation uses the whole validation part of the first split.
    const actual = validationLabels.arraySync();
    const predicted = predictToArray(model, validationData).flat();
    if (config.DEBUG) {
        console.log(actual);
        console.log(predicted);
    }
    const testAccuracy = await utils.calcAvgAccInPrecentages(actual, predicted);

    return { model, testAccuracy, runningTimeInMin };
}

// Public API of this module: model training and inference entry points.
module.exports = { train, predict };