Keras 'Sequential' object has no attribute 'predict_classes' 错误解决
问题描述
在使用 Keras 进行多分类模型的性能评估时,您可能会遇到以下错误:
AttributeError: 'Sequential' object has no attribute 'predict_classes'
这个错误通常发生在使用 TensorFlow 2.6 或更高版本时,当尝试使用 model.predict_classes()
方法获取预测类别时会出现。虽然这段代码在 TensorFlow 早期版本中可以正常工作,但由于 API 更新,该方法已被弃用并最终移除。
解决方案
根据您使用的分类类型(多分类或二分类),有以下几种解决方案:
多分类问题(使用 softmax 激活函数)
对于多分类问题,替代 predict_classes()
的方法是:
# 获取预测概率
predictions = model.predict(X_test)
# 获取预测类别(取概率最大的类别)
y_pred_classes = np.argmax(predictions, axis=-1)
或者简写为:
y_pred_classes = np.argmax(model.predict(X_test), axis=-1)
二分类问题(使用 sigmoid 激活函数)
对于二分类问题,可以使用以下方法:
# 方法1:使用阈值0.5进行分类
y_pred_classes = (model.predict(X_test) > 0.5).astype("int32")
# 方法2:使用四舍五入
y_pred = model.predict(X_test)
y_pred_classes = np.round(y_pred).astype(int)
代码示例
以下是一个完整的多分类示例,展示如何正确进行预测:
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# 创建模型
model = Sequential()
model.add(Dense(24, input_dim=13, activation='relu'))
model.add(Dense(18, activation='relu'))
model.add(Dense(6, activation='softmax')) # 多分类,使用softmax
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# 训练模型
history = model.fit(X_train, y_train, batch_size=256, epochs=10, verbose=2, validation_split=0.2)
# 评估模型
score, acc = model.evaluate(X_test, y_test, verbose=2, batch_size=256)
print('Test accuracy:', acc)
# 正确的预测方法(替代 predict_classes)
predictions = model.predict(X_test)
yhat_classes = np.argmax(predictions, axis=-1) # 获取预测类别
版本兼容性说明
WARNING
从 TensorFlow 2.6 开始,predict_classes()
方法已被完全移除。在 TensorFlow 2.5 中,该方法仍然可用但会显示弃用警告。
如果您使用的是 TensorFlow 2.5,您可能会看到以下警告:
UserWarning:
model.predict_classes()
is deprecated and will be removed after 2021-01-01. Please use instead:np.argmax(model.predict(x), axis=-1)
for multi-class classification, or(model.predict(x) > 0.5).astype("int32")
for binary classification.
总结
predict_classes()
方法已在 TensorFlow 2.6+ 中移除- 对于多分类问题,使用
np.argmax(model.predict(x), axis=-1)
- 对于二分类问题,使用
(model.predict(x) > 0.5).astype("int32")
- 建议更新代码以避免未来版本兼容性问题
通过使用这些替代方法,您可以继续在最新版本的 TensorFlow 中获取模型的预测结果,而不会遇到属性错误。