from keras import utils
from keras.datasets import cifar10
(x_train,y_train),(x_test,y_test) = cifar10.load_data()

nums_names = ('airplane','automobile','bird','cat','deer','dog','frog','horse','ship','chuck')
y_train = utils.np_utils.to_categorical(y_train,len(nums_names))

出现这样问题是因为keras 的版本过高

解决方法:

1.把keras 重装 把版本降低点

2.从tensorflow中调用to_categorical

from tensorflow.keras.utils import to_categorical
y_train = to_categorical(y_train,len(nums_names))
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐