diff --git a/Ipek_plot.png b/Ipek_plot.png new file mode 100644 index 0000000..ec89330 Binary files /dev/null and b/Ipek_plot.png differ diff --git a/analyzer.py b/analyzer.py index f57ef85..9bd193c 100644 --- a/analyzer.py +++ b/analyzer.py @@ -5,8 +5,20 @@ from utils.constants import * import pickle from sklearn import metrics +from sklearn.utils.multiclass import unique_labels +import matplotlib.pyplot as plt from typing import List +import numpy as np +from torch import nn + +from statsmodels.stats.contingency_tables import mcnemar +from plots import save_ipek_plot + + +# from mlxtend.evaluate import permutation_test + + class Analyzer: # input: both network models # return average loss, acc; etc. @@ -23,14 +35,14 @@ def __init__(self, self.model.eval() def soft_voting(self, probs1, probs2): - _, predictions = ((probs1 + probs2) / 2).max(dim=-1) - return predictions - + print(probs1) + return (probs1 + probs2) / 2 + def calculate_metrics( - self, - targets: List, - predictions: List, - average: str = "weighted"): + self, + targets: List, + predictions: List, + average: str = "weighted"): if sum(predictions) == 0: return 0, 0, 0 @@ -42,15 +54,164 @@ def calculate_metrics( return f1, precision, recall + def create_contingency_table(self, targets, predictions1, predictions2): + assert len(targets) == len(predictions1) + assert len(targets) == len(predictions2) + + contingency_table = np.zeros((2, 2)) + + targets_length = len(targets) + contingency_table[0, 0] = sum([targets[i] == predictions1[i] and targets[i] == predictions2[i] for i in + range(targets_length)]) # both predictions are correct + contingency_table[0, 1] = sum([targets[i] == predictions1[i] and targets[i] != predictions2[i] for i in + range(targets_length)]) # predictions1 is correct and predictions2 is wrong + contingency_table[1, 0] = sum([targets[i] != predictions1[i] and targets[i] == predictions2[i] for i in + range(targets_length)]) # predictions1 is wrong and predictions2 is correct + contingency_table[1, 1] = sum([targets[i] != predictions1[i] and targets[i] != predictions2[i] for i in + range(targets_length)]) # both predictions are wrong + + return contingency_table + + def calculate_mcnemars_test(self, targets, predictions1, predictions2): + contingency_table = self.create_contingency_table( + targets, + predictions1, + predictions2) + + result = mcnemar(contingency_table, exact=True) + return result.pvalue + + def calculate_confusion_matrix( + self, + targets, + predictions, + classes, + analysis_folder, + normalize=False, + plot_matrix=True, + title=None): + """ + This function prints and plots the confusion matrix. + Normalization can be applied by setting `normalize=True`. + """ + # Compute confusion matrix + cm = metrics.confusion_matrix(targets, predictions) + # Only use the labels that appear in the data + labels = unique_labels(targets, predictions) + classes = classes[labels] + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + + ax = None + if plot_matrix: + ax = self.plot_confusion_matrix(cm, classes, analysis_folder, normalize, title) + + return cm, ax + + def plot_confusion_matrix( + self, + cm, + classes, + analysis_folder, + normalize=False, + title=None, + print_scores=True, + cmap=plt.cm.Blues): + + fig, ax = plt.subplots() + im = ax.imshow(cm, interpolation='nearest', cmap=cmap) + ax.figure.colorbar(im, ax=ax) + # We want to show all ticks... + ax.set(xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + # ... 
and label them with the respective list entries + xticklabels=classes, yticklabels=classes, + title=title, + ylabel='True label', + xlabel='Predicted label') + + ax.set_ylim(4.5, -0.5) # fix the classes + + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + rotation_mode="anchor") + + # Loop over data dimensions and create text annotations. + + if print_scores: + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], fmt), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + fig.savefig(os.path.join(analysis_folder, f'confusion_matrix_{title}')) + + return ax + + def compute_confusion_matrix( + self, + targets, + combined_predictions, + classifier_predictions, + analysis_folder): + + classes = np.array(['Pop', 'Hip-Hop', 'Rock', 'Metal', 'Country']) + combined_cm, _ = self.calculate_confusion_matrix(targets, combined_predictions, classes, analysis_folder, + normalize=False, title='Combined') + lstm_cm, _ = self.calculate_confusion_matrix(targets, classifier_predictions, classes, analysis_folder, + normalize=False, title='LSTM') + + diff_cm = combined_cm - lstm_cm + ones = np.ones(diff_cm.shape, dtype=np.int32) * (-1) + ones += np.eye(diff_cm.shape[0], dtype=np.int32) * 2 + diff_cm = ones * diff_cm + + self.plot_confusion_matrix( + diff_cm, + classes, + analysis_folder, + normalize=False, + title='Difference', + cmap=plt.cm.RdYlGn, + print_scores=False) + + plt.show() + + def compute_significance(self, targets, combined_predictions, classifier_predictions): + mcnemars_p_value = self.calculate_mcnemars_test(targets, classifier_predictions, combined_predictions) + alpha_value = 0.05 + mcnemars_significant = mcnemars_p_value < alpha_value + print(f'Mcnemars: {mcnemars_significant} | p-value: {mcnemars_p_value}') + + def compute_f1(self, targets, combined_predictions, classifier_predictions, vaes_predictions): + combined_f1, combined_precision, combined_recall = self.calculate_metrics(targets, combined_predictions) + classifier_f1, classifier_precision, classifier_recall = self.calculate_metrics(targets, classifier_predictions) + vae_f1, vae_precision, vae_recall = self.calculate_metrics(targets, vaes_predictions) + + print(f'Combined F1: {combined_f1}\nLSTM F1: {classifier_f1}\nVAE F1: {vae_f1}') + + def ensure_analyzer_filesystem(self): + analysis_folder = os.path.join('local_data', 'analysis') + if not os.path.exists(analysis_folder): + os.mkdir(analysis_folder) + + return analysis_folder def analyze_misclassifications(self, test_logs): + if test_logs is not None: - with open('logs1k.pickle', 'wb') as handle: + with open('logs_full_on_full.pickle', 'wb') as handle: pickle.dump(test_logs, handle, protocol=pickle.HIGHEST_PROTOCOL) else: - with open('logs1k.pickle', 'rb') as handle: + with open('logs_full_on_full.pickle', 'rb') as handle: test_logs = pickle.load(handle) + analysis_folder = self.ensure_analyzer_filesystem() + combined_scores = torch.stack(test_logs['final_scores']).view(-1, 5) classifier_scores = torch.stack(test_logs['combination']['classifier_scores']).view(-1, 5) vaes_scores = torch.stack(test_logs['combination']['vaes_scores']).view(-1, 5) @@ -60,8 +221,6 @@ def analyze_misclassifications(self, test_logs): _, classifier_predictions = classifier_scores.max(dim=-1) _, vaes_predictions = vaes_scores.max(dim=-1) - - # combined_predictions = self.soft_voting(vaes_scores, 
classifier_scores) # print('targets', targets) # print('combine', combined_predictions) # print('classif', classifier_predictions) @@ -71,32 +230,116 @@ def analyze_misclassifications(self, test_logs): combined_compare = combined_predictions.eq(targets) vaes_compare = vaes_predictions.eq(targets) - classifier_misfire_indices = (classifier_compare == 0).nonzero() # get misclassifications - vae_improved = vaes_compare[classifier_misfire_indices].float().mean() - print('VAE classified', vae_improved, 'of the LSTM misclassifications correctly.') - - # print('Elbo values', vaes_scores) - print('Accuracies:' '\n-Combined:', combined_compare.float().mean().item(), '\n-Base Classifier:', classifier_compare.float().mean().item(), '\n-Classify By Elbo:', vaes_compare.float().mean().item()) + self.uncertainty_analysis(vaes_scores, classifier_scores, targets, combined_scores) + + ''' + F1 score + ''' targets = targets.detach().tolist() combined_predictions = combined_predictions.tolist() classifier_predictions = classifier_predictions.tolist() vaes_predictions = vaes_predictions.tolist() - combined_f1, combined_precision, combined_recall = self.calculate_metrics(targets, combined_predictions) - classifier_f1, classifier_precision, classifier_recall = self.calculate_metrics(targets, classifier_predictions) + print("----------------------------------------------") + self.compute_f1(targets, combined_predictions, classifier_predictions, vaes_predictions) - print(f'Combined F1: {combined_f1}\nClassifier F1: {classifier_f1}') + print("----------------------------------------------") + self.compute_significance(targets, combined_predictions, classifier_predictions) + + print("----------------------------------------------") + self.compute_confusion_matrix(targets, combined_predictions, classifier_predictions, analysis_folder) # check if combination correctly classified these? 
check how many # print(combined_compare[classifier_misfire_indices]) # print(classifier_misfire_indices) + # IPEK PLOT + + classifier_misfire_indices = (classifier_compare == 0).nonzero() # get misclassifications + combined_misfire_indices = (combined_compare == 0).nonzero() # get misclassifications + vaes_misfire_indices = (vaes_compare == 0).nonzero() # get misclassifications + + len_of_dataset = len(classifier_compare.tolist()) + + # Compare LSTM with VAE + vae_right_class_wrong = vaes_compare[classifier_misfire_indices].tolist().count([1]) / len_of_dataset + vae_wrong_class_wrong = classifier_compare[vaes_misfire_indices].tolist().count([0]) / len_of_dataset + vae_wrong_class_right = classifier_compare[vaes_misfire_indices].tolist().count([1]) / len_of_dataset + + # Compare LSTM with Combined + comb_right_class_wrong = combined_compare[classifier_misfire_indices].tolist().count([1]) / len_of_dataset + comb_wrong_class_wrong = classifier_compare[combined_misfire_indices].tolist().count([0]) / len_of_dataset + comb_wrong_class_right = classifier_compare[combined_misfire_indices].tolist().count([1]) / len_of_dataset + + lstm_classifier = classifier_compare.tolist().count(1) / len_of_dataset + + save_ipek_plot([lstm_classifier, 1 - lstm_classifier, 0, 0], + [1 - vae_wrong_class_wrong - vae_wrong_class_right - + vae_right_class_wrong, vae_wrong_class_right, + vae_right_class_wrong, vae_wrong_class_wrong], + [1 - comb_wrong_class_wrong - comb_wrong_class_right - + comb_right_class_wrong, comb_wrong_class_right, + comb_right_class_wrong, comb_wrong_class_wrong], + 'Ipek_plot') + + def uncertainty_analysis(self, vaes_scores, classifier_scores, targets, combined_scores): + + _, combined_predictions = combined_scores.max(dim=-1) + _, classifier_predictions = classifier_scores.max(dim=-1) + _, vaes_predictions = vaes_scores.max(dim=-1) + + classifier_compare = classifier_predictions.eq(targets) + combined_compare = combined_predictions.eq(targets) + vaes_compare = vaes_predictions.eq(targets) + + ''' + uncertainty analyses + ''' + + vaes_scores_softmax = nn.Softmax(dim=-1)(vaes_scores) + classifier_predictions_indices, _ = classifier_scores.max(dim=-1) + classifier_prediction_values = classifier_scores[np.arange(0, len(classifier_scores)), + classifier_predictions_indices.long()] + + classifier_uncertain_indices = ((classifier_prediction_values < 0.50).eq( + classifier_prediction_values > 0.00)).nonzero() + + # vae_scores_for_uncertain = vaes_scores[classifier_uncertain_indices] + vae_scores_for_uncertain, pred_vae = vaes_scores_softmax[classifier_uncertain_indices.long()].max(dim=-1) + classifier_uncertain_scores, pred_class = classifier_scores[classifier_uncertain_indices.long()].max(dim=-1) + true = targets[classifier_uncertain_indices.long()] + + print('LSTM is uncertain in', len(classifier_uncertain_indices) / len(classifier_scores), 'samples.') + classifier_uncertain_indices_correct = classifier_compare[classifier_uncertain_indices].nonzero() + classifier_uncertain_indices_false = (classifier_compare[classifier_uncertain_indices] == 0).nonzero() + print('-', len(classifier_uncertain_indices_false) / len(classifier_uncertain_indices), + 'of these are misclassifications.') + + classifier_uncertain_correct_VAE = vaes_compare[classifier_uncertain_indices_correct] + classifier_uncertain_false_VAE = vaes_compare[classifier_uncertain_indices_false] + + print('- -', classifier_uncertain_correct_VAE.float().mean().item(), + 'of the CORRECT uncertain classifications are correctly classified by the 
VAE.') + print('- -', classifier_uncertain_false_VAE.float().mean().item(), + 'of the uncertain MISclassifications are correctly classified by the VAE.') + + classifier_uncertain_correct_Combined = combined_compare[classifier_uncertain_indices_correct] + classifier_uncertain_false_Combined = combined_compare[classifier_uncertain_indices_false] + print('- - -', classifier_uncertain_correct_Combined.float().mean().item(), + 'of the CORRECT uncertain classifications are correctly classified by the Combined Model.') + print('- - -', classifier_uncertain_false_Combined.float().mean().item(), + 'of the uncertain MISclassifications are correctly classified by the Combined Model.') + # print('cla', classifier_uncertain_scores.tolist()) + # print('vae', vae_scores_for_uncertain.tolist()) + # print('cla', pred_class.tolist()) + # print('vae', pred_vae.tolist()) + # print('tru', true.tolist()) diff --git a/jobs/train_sentence_vae.sh b/jobs/train_sentence_vae.sh index 96113a8..ecbcaea 100644 --- a/jobs/train_sentence_vae.sh +++ b/jobs/train_sentence_vae.sh @@ -18,4 +18,4 @@ export LD_LIBRARY_PATH=/hpc/eb/Debian9/cuDNN/7.1-CUDA-8.0.44-GCCcore-5.4.0/lib64 for genre in 'Pop' 'Rock' 'Hip-Hop' 'Metal' 'Country' do srun python3 -u main.py --generator SentenceVAE --dataset_class LyricsRawDataset --loss VAELoss --batch_size 16 --device cuda --eval_freq 100 --embedding_size 256 --hidden_dim 64 --genre $genre --run_name 'sentence-vae-genre-'$genre >> 'output/train-sentence-vae-genre-'$genre'-seed-42.out' -done +done \ No newline at end of file diff --git a/logs1k.pickle b/logs1k.pickle new file mode 100644 index 0000000..9dd790e Binary files /dev/null and b/logs1k.pickle differ diff --git a/main.py b/main.py index 3fb7f18..86cf1f4 100644 --- a/main.py +++ b/main.py @@ -89,9 +89,9 @@ def main(arguments: argparse.Namespace): # if we are in train mode.. if arguments.test_mode: - tester = Tester(model, data_loader_test, device=device, data_loader_sentence=data_loader_sentenceVAE) - test_logs = tester.test() - # test_logs = None + # tester = Tester(model, data_loader_test, device=device, data_loader_sentence=data_loader_sentenceVAE) + # test_logs = tester.test() + test_logs = None if arguments.analysis: analyzer = Analyzer(model, device=device, num_classes=arguments.num_classes) analyzer.analyze_misclassifications(test_logs) diff --git a/plots.py b/plots.py new file mode 100644 index 0000000..4517def --- /dev/null +++ b/plots.py @@ -0,0 +1,52 @@ +import numpy as np +import matplotlib.pyplot as plt + +def save_ipek_plot(lstm_numbers, vae_numbers, combined_numbers, name): + + category_names = ['Correctly Classified by LSTM', 'Misclassified by LSTM', 'Other model correct, LSTM wrong', 'Both wrong'] + results = { + 'LSTM': lstm_numbers, + 'VAE': vae_numbers, + 'Combined': combined_numbers + } + + + """ + Parameters + ---------- + results : dict + A mapping from question labels to a list of answers per category. + It is assumed all lists contain the same number of entries and that + it matches the length of *category_names*. + category_names : list of str + The category labels. 
+ """ + labels = list(results.keys()) + data = np.array(list(results.values())) + data_cum = data.cumsum(axis=1) + category_colors = plt.get_cmap('RdYlGn')( + np.linspace(0.15, 0.85, 4)) + category_colors = np.concatenate((category_colors, category_colors), axis=0) + + fig, ax = plt.subplots(figsize=(9.2, 5)) + ax.invert_yaxis() + ax.xaxis.set_visible(False) + ax.set_xlim(0, np.sum(data, axis=1).max()) + + for i, (colname, color) in enumerate(zip(category_names, category_colors)): + widths = data[:, i] + starts = data_cum[:, i] - widths + ax.barh(labels, widths, left=starts, height=0.5, + label=colname, color=color) + xcenters = starts + widths / 2 + + r, g, b, _ = color + text_color = 'white' if r * g * b < 0.5 else 'darkgrey' + for y, (x, c) in enumerate(zip(xcenters, widths)): + if c > 0.025: + ax.text(x, y, str(c), ha='center', va='center', + color=text_color) + ax.legend(ncol=len(category_names), bbox_to_anchor=(0, 1), + loc='lower left', fontsize='small') + + plt.savefig(name +'.png') \ No newline at end of file diff --git a/preprocessing/lyrics_preprocessing.py b/preprocessing/lyrics_preprocessing.py index 8c1c1ae..51fc486 100644 --- a/preprocessing/lyrics_preprocessing.py +++ b/preprocessing/lyrics_preprocessing.py @@ -18,6 +18,7 @@ def save_dataset_text(song_entries, embeddings_folder_path, filename): for song_entry in song_entries: embeddings_file.write(f'{song_entry.lyrics}\n') + ensure_current_directory() main_path = os.path.join('local_data', 'data') @@ -64,7 +65,7 @@ def save_dataset_text(song_entries, embeddings_folder_path, filename): song_entries_by_genre[row[4]].append(song_entry) genres = list(song_entries_by_genre.keys()) -songs_limit = 100 +songs_limit = 13000 for genre in genres: song_entries_by_genre[genre] = song_entries_by_genre[genre][:songs_limit] @@ -90,6 +91,26 @@ def save_dataset_text(song_entries, embeddings_folder_path, filename): validation_song_entries = sorted([item for value in validation_data.values() for item in value], key=lambda song: len(song.lyrics)) test_song_entries = sorted([item for value in test_data.values() for item in value], key=lambda song: len(song.lyrics)) +empties_train, empties_test, empties_val = [],[],[] +for i, song in enumerate(train_song_entries): + if song.lyrics == '': + empties_train.append(i) +for i, song in enumerate(test_song_entries): + if song.lyrics == '': + empties_test.append(i) +for i, song in enumerate(validation_song_entries): + if song.lyrics == '': + empties_val.append(i) + +for index in sorted(empties_train, reverse=True): + del train_song_entries[index] +indexes = [2, 3, 5] +for index in sorted(empties_test, reverse=True): + del test_song_entries[index] +indexes = [2, 3, 5] +for index in sorted(empties_val, reverse=True): + del validation_song_entries[index] + lines_counter = 0 for song in train_song_entries: song.start_index = lines_counter
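
A minimal, self-contained sketch of the paired comparison behind `create_contingency_table` and `compute_significance` above: build the 2x2 table of joint correctness for the two models and pass it to statsmodels' exact McNemar test. The prediction arrays below are invented for illustration and are not taken from the patch.

    # Illustrative only -- mirrors create_contingency_table / calculate_mcnemars_test.
    import numpy as np
    from statsmodels.stats.contingency_tables import mcnemar

    targets      = np.array([0, 1, 2, 2, 1, 0, 3, 4])
    predictions1 = np.array([0, 1, 2, 1, 1, 0, 3, 0])  # e.g. the LSTM classifier
    predictions2 = np.array([0, 1, 1, 2, 1, 2, 3, 4])  # e.g. the combined model

    correct1 = predictions1 == targets
    correct2 = predictions2 == targets

    # Rows: model 1 correct/wrong; columns: model 2 correct/wrong.
    table = np.array([
        [np.sum(correct1 & correct2),  np.sum(correct1 & ~correct2)],
        [np.sum(~correct1 & correct2), np.sum(~correct1 & ~correct2)],
    ])

    # The exact (binomial) McNemar test uses only the off-diagonal disagreement cells.
    result = mcnemar(table, exact=True)
    print(f'McNemar p-value: {result.pvalue:.4f}, significant: {result.pvalue < 0.05}')

Because only the disagreement counts matter, two models with identical overall accuracy can still differ significantly if their errors fall on different samples.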
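
The sign flip applied to `diff_cm` in `compute_confusion_matrix` makes every improvement of the combined model over the LSTM show up as a positive (green) cell under the RdYlGn colormap: gains on the diagonal are kept, while reductions in off-diagonal confusions are negated into positive values. A small sketch with invented counts:

    # Illustrative only -- the confusion-matrix counts are invented.
    import numpy as np

    combined_cm = np.array([[50, 2], [4, 44]])
    lstm_cm     = np.array([[47, 5], [6, 42]])

    diff_cm = combined_cm - lstm_cm              # [[ 3, -3], [-2,  2]]

    # -1 everywhere except +1 on the diagonal.
    sign = np.ones(diff_cm.shape, dtype=np.int32) * (-1)
    sign += np.eye(diff_cm.shape[0], dtype=np.int32) * 2
    print(sign * diff_cm)                        # [[3, 3], [2, 2]] -- all improvements positive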
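
For the uncertainty analysis, selecting "uncertain" samples hinges on the top softmax score per sample; note that `Tensor.max(dim=-1)` returns a `(values, indices)` pair, so the values are the top scores and the indices are the predicted classes. A standalone sketch with invented scores and targets, not taken from the patch:

    # Illustrative only -- scores and targets are invented.
    import torch

    classifier_scores = torch.tensor([[0.10, 0.60, 0.10, 0.10, 0.10],
                                      [0.30, 0.20, 0.25, 0.15, 0.10],
                                      [0.05, 0.05, 0.80, 0.05, 0.05]])
    targets = torch.tensor([1, 0, 3])

    top_scores, predictions = classifier_scores.max(dim=-1)
    correct = predictions.eq(targets)

    # "Uncertain": the winning class gets less than half the probability mass.
    uncertain = (top_scores < 0.50).nonzero(as_tuple=True)[0]

    print('uncertain fraction:', len(uncertain) / len(classifier_scores))
    print('accuracy on uncertain samples:', correct[uncertain].float().mean().item())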