"""Train a Multinomial Naive Bayes text classifier and plot its confusion matrix.

The script reads pre-labelled train/test CSV files (columns ``text`` and
``label``), builds bag-of-words features with ``CountVectorizer``, fits
``MultinomialNB``, prints the test accuracy, and shows the confusion matrix
as a heatmap.
"""
import copy
import itertools
import math
import random
import re
from collections import Counter
from typing import TextIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from bs4 import BeautifulSoup
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

# --- Load data ---------------------------------------------------------------
# Labels are read pre-computed from the "label" column of each CSV.
data = pd.read_csv("all_data/train_all_data_label.csv")
print("len(data) ", len(data))
data_test = pd.read_csv("all_data/test_all_data_label.csv")
print("len(data_test) ", len(data_test))

print(data.head())
print(data.head())

# The train/test split is fixed by the two files; no random split is done here.
X_train = data["text"].values
y_train = data["label"].values
X_test = data_test["text"].values
y_test = data_test["label"].values

# --- Vectorize ---------------------------------------------------------------
# Tokens must contain at least one ASCII letter; English stop words are dropped.
cv = CountVectorizer(
    strip_accents="ascii",
    token_pattern=u'(?ui)\\b\\w*[a-z]+\\w*\\b',
    lowercase=True,
    stop_words='english',
)
X_train_cv = cv.fit_transform(X_train)
X_test_cv = cv.transform(X_test)  # reuse the vocabulary learned on the training set

# `get_feature_names` was removed in scikit-learn 1.2; prefer the replacement
# and fall back to the old name only on releases that predate it.
if hasattr(cv, "get_feature_names_out"):
    feature_names = cv.get_feature_names_out()
else:
    feature_names = cv.get_feature_names()

# NOTE: `toarray()` densifies the document-term matrix, which can be large.
word_freq_df = pd.DataFrame(X_train_cv.toarray(), columns=feature_names)
# Total count of every vocabulary term, most frequent first.
top_words_df = pd.DataFrame(word_freq_df.sum()).sort_values(0, ascending=False)

# --- Train and evaluate ------------------------------------------------------
naive_bayes = MultinomialNB()
naive_bayes.fit(X_train_cv, y_train)
predictions = naive_bayes.predict(X_test_cv)

print('Accuracy score: ', accuracy_score(y_test, predictions))
print("y_test = ", y_test)
print("predictions = ", predictions)

# --- Confusion matrix --------------------------------------------------------
# NOTE(review): assumes the sorted label values in the CSVs correspond, in
# order, to these class names -- confirm against the labelling step.
top_10_class_list = ['earn', 'acq', 'crude', 'grain', 'money-supply',
                     'money-fx', 'coffee', 'trade', 'veg-oil', 'interest']

cm = confusion_matrix(y_test, predictions)
print("cm ", cm)

# Only use the class names as tick labels when their count matches the matrix
# size; otherwise let seaborn pick labels rather than mislabel the axes.
if cm.shape[0] == len(top_10_class_list):
    tick_labels = top_10_class_list
else:
    tick_labels = "auto"

sns.heatmap(
    cm,
    square=True,
    annot=True,
    fmt="d",  # integer counts; the default format switches to scientific notation
    cmap='Blues_r',
    cbar=False,
    xticklabels=tick_labels,
    yticklabels=tick_labels,
)
# confusion_matrix rows are true labels (heatmap y-axis) and columns are
# predicted labels (heatmap x-axis).
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.show()