|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
# Multiprocessing set up.
# Pool(None) sizes the pool to the machine's CPU count; 1 forces a single
# worker process when parallel execution is disabled.
num_workers = None if parallel else 1
pool = Pool(num_workers)
# NOTE(review): this pool is never close()d / join()ed in the visible code;
# confirm it is shut down after its last use, otherwise its worker
# processes outlive this step.

train_data, test_data = get_mnist_data_iters(
    data_dir, train_size, test_size, full_test_set, seed=seed)

# Train on every training item. Only element 0 of each item is sent to the
# workers (presumably the image, with the label at index 1 -- TODO confirm
# against get_mnist_data_iters).
LOG.info("Training on %s images...", len(train_data))
train_partial = partial(train_image, perturb_factor=perturb_factor)
# The large finite timeout on .get() is presumably the common workaround
# that keeps the blocking wait interruptible by Ctrl-C (a timeout-less wait
# was not interruptible on older Pythons).
train_results = pool.map_async(
    train_partial, [d[0] for d in train_data]).get(9999999)
# Transpose the per-image result tuples into per-factor sequences.
# Materialize the zip: on Python 3 zip() returns a one-shot iterator, and
# all tasks executed from the same pickled copy of test_partial would share
# (and exhaust) that single iterator after the first test_image call.
all_model_factors = list(zip(*train_results))

LOG.info("Testing on %s images...", len(test_data))
test_partial = partial(test_image, model_factors=all_model_factors,
                       pool_shape=pool_shape)
test_results = pool.map_async(
    test_partial, [d[0] for d in test_data]).get(9999999)
|