Generative Adversarial Networks Projects

Testing the models

To test the networks, first create the generator and the discriminator networks, then load the trained weights, and finally call the generator's predict() method to generate 3D volumes from random latent vectors:

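import os
import numpy as np

# Create models (build_generator() and build_discriminator() are defined earlier)
generator = build_generator()
discriminator = build_discriminator()

# Load the weights saved after training (HDF5 files, matched by layer name)
generator.load_weights(os.path.join(generated_volumes_dir, "generator_weights.h5"), by_name=True)
discriminator.load_weights(os.path.join(generated_volumes_dir, "discriminator_weights.h5"), by_name=True)

# Sample latent vectors z ~ N(0, 0.33^2) and generate 3D volumes
z_sample = np.random.normal(0, 0.33, size=[batch_size, 1, 1, 1, z_size]).astype(np.float32)
generated_volumes = generator.predict(z_sample, verbose=1)  # verbose=1 shows a progress bar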
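The sketch below isolates the two array-handling steps around predict() so they can be checked without a trained model: drawing latent vectors with the same shape and spread as z_sample, and turning the generator's output into binary voxel grids. It is a minimal sketch under stated assumptions: the helper names sample_latent and binarize_volumes are hypothetical, the generator is assumed to end in a sigmoid so its outputs lie in [0, 1], the output shape (batch, 64, 64, 64, 1) is assumed, and the 0.5 threshold is an illustrative choice. A random array stands in for the real generator.predict() output.

```python
import numpy as np

def sample_latent(batch_size, z_size, std=0.33, seed=None):
    """Draw latent vectors shaped like z_sample: (batch, 1, 1, 1, z_size)."""
    rng = np.random.default_rng(seed)
    # Each sample is a single 1x1x1 "voxel" carrying z_size channels
    return rng.normal(0.0, std, size=(batch_size, 1, 1, 1, z_size)).astype(np.float32)

def binarize_volumes(volumes, threshold=0.5):
    """Turn (batch, D, H, W, 1) occupancy values into boolean voxel grids (assumed sigmoid output)."""
    # Drop the trailing channel axis, then mark voxels at or above the threshold as occupied
    return np.squeeze(volumes, axis=-1) >= threshold

z = sample_latent(batch_size=4, z_size=200, seed=0)
print(z.shape, z.dtype)  # (4, 1, 1, 1, 200) float32

# Hypothetical stand-in for generator.predict(z): random values in [0, 1)
fake_output = np.random.default_rng(1).random((4, 64, 64, 64, 1), dtype=np.float32)
voxels = binarize_volumes(fake_output)
print(voxels.shape, voxels.dtype)  # (4, 64, 64, 64) bool
```

Keeping the threshold as a parameter makes it easy to trade off between sparser and denser shapes when inspecting generated volumes.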

In this section, we loaded the trained generator and discriminator of the 3D-GAN and used the generator to produce new 3D volumes. In the next section, we will explore hyperparameter tuning and various hyperparameter optimization options.