当前位置: 代码迷 >> 综合 >> keras报错 : ValueError: logits and labels must have the same shape
  详细解决方案

keras 报错：ValueError: logits and labels must have the same shape

热度:52   发布时间:2023-12-08 07:23:52.0

keras 报错：ValueError: logits and labels must have the same shape

  • 问题背景
  • 问题原因
  • 解决办法
  • 整体代码

问题背景

在逐步搭建文本情感分类器的过程中，先将数据处理成向量形式，
再送入 MLP 模型中进行拟合训练，
结果报错：
ValueError: logits and labels must have the same shape ((None, 1) vs (None, 2))

问题原因

原因:

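  • logits（模型输出）和 labels（标签）必须具有相同的形状。
    模型最后一层是 Dense(1, activation='sigmoid')，输出形状为 (None, 1)；而标签经过 to_categorical 转成 one-hot 编码后形状变为 (None, 2)，两者不匹配，因此报错。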

解决办法

去掉将整数标签（0/1）转换为 one-hot 编码（to_categorical）的操作：

## ytrain, ytest = to_categorical(ytrain), to_categorical(ytest)

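即把这一行注释掉，直接用 0/1 整数标签训练即可。（另一种做法是保留 one-hot 标签，把输出层改为 Dense(2, activation='softmax')，并改用 categorical_crossentropy 损失。）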

整体代码

import functools
import re
import string
from collections import Counter
from os import listdir

import numpy as np
from nltk.corpus import stopwords
from tensorflow.keras.preprocessing.text import Tokenizer

# load doc into memory
def load_doc(filename):
    """Read a UTF-8 text file and return its entire contents.

    Args:
        filename: Path of the file to read.

    Returns:
        The full text of the file as a single string.
    """
    # The context manager closes the handle even if read() raises.
    with open(filename, 'r', encoding='utf-8') as file:
        return file.read()
# Compiled once: matches any single ASCII punctuation character.
_PUNCT_RE = re.compile('[%s]' % re.escape(string.punctuation))


@functools.lru_cache(maxsize=1)
def _english_stop_words():
    """Load NLTK's English stop-word list once and reuse it for every document."""
    return frozenset(stopwords.words('english'))


def clean_doc(doc):
    """Turn raw document text into a list of clean tokens.

    The steps are applied in order:
    split on whitespace, lowercase, strip punctuation, keep purely alphabetic
    tokens, drop NLTK English stop words, and drop tokens of length <= 1.

    Args:
        doc: Raw document text.

    Returns:
        List of cleaned tokens, in their original order (duplicates kept).
    """
    tokens = [w.lower() for w in doc.split()]
    tokens = [_PUNCT_RE.sub('', w) for w in tokens]
    # Anything still containing digits or other symbols is discarded.
    tokens = [w for w in tokens if w.isalpha()]
    stop_words = _english_stop_words()
    tokens = [w for w in tokens if w not in stop_words]
    return [w for w in tokens if len(w) > 1]

######### Reviews to Lines of Tokens ########################
# load doc, clean and return line of tokens
def doc_to_line(filename, vocab):
    """Load a document, clean it, and keep only in-vocabulary tokens.

    Args:
        filename: Path of the document to load.
        vocab: Container of allowed tokens (membership-tested with ``in``;
            the script passes a set, which makes each check O(1)).

    Returns:
        The surviving tokens joined by single spaces ('' if none survive).
    """
    tokens = clean_doc(load_doc(filename))
    return ' '.join(w for w in tokens if w in vocab)
def process_docs(directory, vocab, is_train):
    """Clean every document in *directory* that belongs to the requested split.

    Files whose names start with ``'cv9'`` form the test split; all other
    files form the training split.

    Args:
        directory: Folder containing one document per file.
        vocab: Allowed tokens, forwarded to doc_to_line.
        is_train: True to load the training split, False for the test split.

    Returns:
        List of cleaned, space-joined token lines, one per kept file.
    """
    lines = []
    # listdir() order is arbitrary; sorting makes the dataset order (and thus
    # training runs) reproducible across machines and filesystems.
    for filename in sorted(listdir(directory)):
        is_test_file = filename.startswith('cv9')
        # Skip files that belong to the other split.
        if is_train and is_test_file:
            continue
        if not is_train and not is_test_file:
            continue
        path = directory + '/' + filename
        lines.append(doc_to_line(path, vocab))
    return lines
def load_clean_dataset(vocab, is_train,
                       base_dir='./dataset/review_polarity/txt_sentoken'):
    """Load one split of the review dataset with aligned 0/1 labels.

    Args:
        vocab: Allowed tokens, forwarded to process_docs.
        is_train: True for the training split, False for the test split.
        base_dir: Folder holding the ``neg`` and ``pos`` review subfolders.

    Returns:
        ``(docs, labels)``, where ``docs`` lists the negative reviews followed
        by the positive reviews, and ``labels`` is a 1-D NumPy array holding
        0 for each negative and 1 for each positive.
        These are plain integer labels, matching the single sigmoid output
        used by define_model; one-hot encoding them (to_categorical) would
        produce shape (n, 2) and a shape-mismatch error.
    """
    neg = process_docs(base_dir + '/neg', vocab, is_train)
    pos = process_docs(base_dir + '/pos', vocab, is_train)
    docs = neg + pos
    labels = [0] * len(neg) + [1] * len(pos)
    return docs, np.array(labels)
def create_tokenizer(lines):
    """Fit a Keras Tokenizer on *lines* and return it.

    Args:
        lines: Iterable of text lines (here: space-joined cleaned tokens).

    Returns:
        The fitted Tokenizer, whose word index is built from *lines*.
    """
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(lines)
    return tokenizer
# load the vocabulary: whitespace-separated tokens read from vocab.txt
vocab_filename = 'vocab.txt'
vocab = load_doc(vocab_filename)
vocab = set(vocab.split())

# load all reviews; the training split skips 'cv9*' files, the test split keeps only them
train_docs, ytrain = load_clean_dataset(vocab, True)
test_docs, ytest = load_clean_dataset(vocab, False)

# create the tokenizer, fitted on the training docs only
tokenizer = create_tokenizer(train_docs)

# encode data: one row per document, scored with mode='freq'
Xtrain = tokenizer.texts_to_matrix(train_docs, mode='freq')
Xtest = tokenizer.texts_to_matrix(test_docs, mode='freq')
print(Xtrain.shape, Xtest.shape)
print(type(Xtrain), type(ytrain))

# Model-building names (Sequential, ...) used by define_model and below.
# This import was previously fused onto the print() line above (a SyntaxError).
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.utils import plot_model, to_categorical# define the model
def define_model(n_words, plot_file='model.png'):
    """Build and compile a one-hidden-layer MLP for binary sentiment classification.

    The network is Dense(50, relu) -> Dense(1, sigmoid), trained with
    binary cross-entropy.
    Because the output has shape (batch, 1), the model expects plain 0/1
    labels. One-hot labels of shape (batch, 2), e.g. from to_categorical,
    raise "ValueError: logits and labels must have the same shape".

    Args:
        n_words: Width of each input row (number of encoded-matrix columns).
        plot_file: Path for the architecture diagram; None skips plotting.

    Returns:
        The compiled Sequential model.
    """
    model = Sequential()
    model.add(Dense(50, activation='relu', input_shape=(n_words,)))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    if plot_file is not None:
        try:
            plot_model(model, to_file=plot_file, show_shapes=True)
        except ImportError as err:
            # plot_model needs pydot + graphviz; the diagram is optional, so a
            # missing dependency should not abort the whole training script.
            print(f'Skipping model plot: {err}')
    return model
# Input width = number of columns in the encoded document matrix.
n_words = Xtest.shape[1]
model = define_model(n_words)

# Train directly on the 0/1 integer labels (no to_categorical): their shape
# agrees with the model's single sigmoid output.
model.fit(Xtrain, ytrain, epochs=10, verbose=2)

# Score on the held-out split and report accuracy as a percentage.
loss, acc = model.evaluate(Xtest, ytest, verbose=0)
print(f'Test Accuracy: {acc * 100:f}')
  相关解决方案