import tensorflow as tf
import os
import _pickle as cPickle
import numpy as np
CIFAR_DIR = "./cifar-10-batches-py"
print(os.listdir(CIFAR_DIR))

def load_data(filename):
    """Read one CIFAR-10 batch file; return its image data and labels."""
    with open(filename, 'rb') as f:
        data = cPickle.load(f, encoding="latin1")
    return data['data'], data['labels']
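# Hypothetical sanity check (assumes the dataset is already unpacked under
# CIFAR_DIR): each CIFAR-10 batch file holds 10000 images flattened to
# 3072 values (32 * 32 * 3), so these are the shapes we expect.
sample_data, sample_labels = load_data(os.path.join(CIFAR_DIR, 'data_batch_1'))
print(np.array(sample_data).shape)    # (10000, 3072)
print(np.array(sample_labels).shape)  # (10000,)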
# A small stand-in for tensorflow's Dataset abstraction.
class CifarData:
    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            all_data.append(data)
            all_labels.append(labels)
        self._data = np.vstack(all_data)
        # Scale raw pixel values from [0, 255] to [-1, 1].
        self._data = self._data / 127.5 - 1
        self._labels = np.hstack(all_labels)
        print(self._data.shape)
        print(self._labels.shape)
        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        # Apply the same random permutation to data and labels.
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self, batch_size):
        """Return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                # Reshuffle and restart from the beginning of a new epoch.
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator: end_indicator]
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels
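# Side note on the scaling in __init__ above: dividing by 127.5 and then
# subtracting 1 maps raw pixel values from [0, 255] onto [-1, 1]. A tiny
# standalone check (demo_pixels is illustrative, not part of the pipeline):
demo_pixels = np.array([0.0, 127.5, 255.0])
print(demo_pixels / 127.5 - 1)  # [-1.  0.  1.]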
train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i) for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
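# Minimal usage sketch (assumes the batch files exist at the paths built
# above): shuffle the training set, keep the test set in file order, and
# pull one batch of 20 examples.
train_data = CifarData(train_filenames, True)
test_data = CifarData(test_filenames, False)
batch_data, batch_labels = train_data.next_batch(20)
print(batch_data.shape, batch_labels.shape)  # (20, 3072) (20,)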