Think about how much you know about the world. It's staggering. You understand that we live in a 3D environment where objects move, people talk, and animals fly. There is an enormous amount of data in the world, and much of it is easily accessible; the difficult part is developing algorithms that can analyze and understand this rich data. Generative models are one of the most promising ways to achieve this. They have many short-term applications, but in the long run they have the potential to learn the natural features of a dataset, whether those are categories, pixels, audio samples, or something else entirely.

Generative algorithms

You can group generative algorithms into one of three buckets:

  1. Given a label, they predict the associated features (Naïve Bayes)
  2. Given a hidden representation, they predict the associated features (variational autoencoders, generative adversarial networks)
  3. Given some of the features, they predict the rest (inpainting, imputation)

In this article, we will explore some of the basics of generative adversarial networks (GANs)! GANs have incredible potential because they can learn to mimic, in principle, any distribution of data. In other words, GANs can learn to create worlds strikingly similar to our own in any domain: images, music, speech.

Example GAN architecture

Generative adversarial networks (GANs)

The generative part

  • Called the generator.
  • Given a label, it tries to predict the features.
  • Example: given that an email is labeled as spam, it predicts (generates) the text of the email.
  • Generative models learn the distribution of each class.

The discriminative part

  • Called the discriminator.
  • Given the features, it tries to predict the label.
  • Example: given the text of an email, it predicts (discriminates) whether the email is spam or not spam.
  • Discriminative models learn the boundary between classes.
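The "given a label, predict the features" direction can be sketched with a Naïve Bayes-style word model, the first bucket listed earlier. The sketch counts how often each word appears in each class and then samples words for a chosen label. The toy corpus and the function names (`word_distributions`, `generate_email`) are hypothetical. A real spam model would need smoothing and far more data.

```python
import random
from collections import Counter

# Hypothetical toy corpus of (label, words) pairs
EMAILS = [
    ("spam", "win money now".split()),
    ("spam", "free money offer".split()),
    ("ham", "meeting notes attached".split()),
    ("ham", "lunch meeting tomorrow".split()),
]

def word_distributions(emails):
    """Estimate p(word | label) for each label by counting words."""
    counts = {}
    for label, words in emails:
        counts.setdefault(label, Counter()).update(words)
    return {
        label: {word: n / sum(c.values()) for word, n in c.items()}
        for label, c in counts.items()
    }

def generate_email(dists, label, length, rng):
    """Given a label, sample features (words) from p(word | label)."""
    words = list(dists[label])
    weights = [dists[label][w] for w in words]
    return rng.choices(words, weights=weights, k=length)

dists = word_distributions(EMAILS)
print(generate_email(dists, "spam", 4, random.Random(0)))
```

A discriminative model would skip these per-class distributions and learn only the boundary that separates spam from non-spam.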

How does GAN work?

One neural network, called the generator, generates new data instances, while the other, called the discriminator, evaluates them for authenticity.

You can think of a GAN as a cat-and-mouse game between a counterfeiter (the generator) and the police (the discriminator). The counterfeiter is learning to make fake money, and the police are learning to detect it. Both keep learning and improving: the counterfeiter keeps producing better fakes, and the police keep getting better at detecting them. The end result is a counterfeiter (generator) that has been trained to produce hyper-realistic money!
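The counterfeiter-versus-police game has a precise form. In the original 2014 GAN formulation, the discriminator D maximizes a single value function and the generator G minimizes it. In that function, x is a real sample, z is the random noise fed to the generator, and D(·) is the probability that its input is real:

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The police (D) push D(x) toward 1 and D(G(z)) toward 0. The counterfeiter (G) only appears in the second term and pushes D(G(z)) toward 1.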

Let's explore a concrete example with the MNIST handwritten digit dataset:

MNIST handwritten digit dataset

We will have the generator create new images like the ones in the MNIST dataset, which is taken from the real world. When the discriminator is shown an instance from the real MNIST dataset, its goal is to recognize it as real.

Meanwhile, the generator creates new images and passes them to the discriminator. It does so in the hope that they, too, will be deemed real, even though they are fake. The goal of the generator is to produce passable handwritten digits, that is, to lie without being caught. The goal of the discriminator is to identify the images coming from the generator as fake.

MNIST handwritten digit + GAN architecture

The steps a GAN takes:

  1. The generator takes in random numbers and returns an image.
  2. The generated image is fed into the discriminator alongside a stream of images taken from the actual dataset.
  3. The discriminator takes in both real and fake images and returns probabilities, numbers between 0 and 1, where 1 represents a prediction of authenticity and 0 represents fake.
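The three steps above can be sketched with NumPy. This is a minimal, untrained illustration, not the Keras model shown later. The single-layer `generator` and `discriminator` functions, their random weights, and the uniform-noise stand-in for real MNIST images are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
NOISE_DIM, IMG_SIDE = 100, 28  # 100-d noise and 28x28 images, as in the Keras example

# Hypothetical toy weights: one dense layer per network, random and untrained
W_g = rng.normal(0, 0.05, (NOISE_DIM, IMG_SIDE * IMG_SIDE))
W_d = rng.normal(0, 0.05, (IMG_SIDE * IMG_SIDE, 1))

def generator(z):
    """Step 1: random numbers in, images out (tanh keeps pixels in [-1, 1])."""
    return np.tanh(z @ W_g).reshape(-1, IMG_SIDE, IMG_SIDE)

def discriminator(imgs):
    """Step 3: images in, probability of being real out (sigmoid maps to (0, 1))."""
    logits = imgs.reshape(len(imgs), -1) @ W_d
    return 1.0 / (1.0 + np.exp(-logits))

z = rng.normal(0, 1, (4, NOISE_DIM))
fake = generator(z)
real = rng.uniform(-1, 1, (4, IMG_SIDE, IMG_SIDE))  # stand-in for real MNIST images

# Step 2: real and generated images are fed to the discriminator together
batch = np.concatenate([real, fake])
probs = discriminator(batch)
print(probs.shape)  # (8, 1)
```

Training, not shown here, is what turns these random weights into a generator that produces digit-like images.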

So there are two feedback loops:

  1. The discriminator is in a feedback loop with the ground truth of the images (whether they are real or fake), which we know.
  2. The generator is in a feedback loop with the discriminator (the discriminator labels its images as real or fake, regardless of the truth).
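The two feedback loops correspond to two loss functions. The following is a minimal sketch using binary cross-entropy, the loss that the Keras example below compiles with. The helper names `bce`, `discriminator_loss`, and `generator_loss` and the probability values are hypothetical.

```python
import numpy as np

EPS = 1e-12  # keeps log() away from zero

def bce(p, target):
    """Binary cross-entropy between predicted probabilities p and 0/1 targets."""
    p = np.clip(p, EPS, 1 - EPS)
    return float(-np.mean(target * np.log(p) + (1 - target) * np.log(1 - p)))

def discriminator_loss(p_real, p_fake):
    # Loop 1: the discriminator is scored against the ground truth (real = 1, fake = 0),
    # averaged over the real and fake halves like d_loss in the Keras example
    return 0.5 * (bce(p_real, np.ones_like(p_real)) + bce(p_fake, np.zeros_like(p_fake)))

def generator_loss(p_fake):
    # Loop 2: the generator is scored against the label it wants ("real" = 1),
    # regardless of the fact that its images are fake
    return bce(p_fake, np.ones_like(p_fake))

p_real = np.array([0.9, 0.8])  # hypothetical discriminator outputs on real images
p_fake = np.array([0.1, 0.2])  # hypothetical discriminator outputs on generated images
print(discriminator_loss(p_real, p_fake), generator_loss(p_fake))
```

Here the discriminator is winning, so its loss is low while the generator's loss is high.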

Tips for training GANs

Pretraining the discriminator before you start training the generator establishes a clearer gradient.

When you train the discriminator, hold the generator's values constant; when you train the generator, hold the discriminator's values constant. Each network then trains against a static adversary, which gives it a better read on the gradient it must learn from.
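The "hold one network constant" rule can be made concrete with a deliberately tiny 1-D GAN in NumPy. This is a hypothetical toy model, not part of the Keras example. The generator is G(z) = a*z + b, the discriminator is D(x) = sigmoid(w*x + c), and the gradients are derived by hand. `d_step` changes only the discriminator's parameters, and `g_step` changes only the generator's. The "real" data distribution N(4, 1) is an arbitrary choice.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

class ToyGAN:
    """Hypothetical 1-D GAN: G(z) = a*z + b and D(x) = sigmoid(w*x + c)."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.g = {"a": 1.0, "b": 0.0}  # generator parameters
        self.d = {"w": 0.1, "c": 0.0}  # discriminator parameters

    def generate(self, z):
        return self.g["a"] * z + self.g["b"]

    def discriminate(self, x):
        return sigmoid(self.d["w"] * x + self.d["c"])

    def d_loss(self, x_real, x_fake):
        # -log D(real) - log(1 - D(fake)), each averaged over its batch
        return float(-np.mean(np.log(self.discriminate(x_real)))
                     - np.mean(np.log(1.0 - self.discriminate(x_fake))))

    def d_step(self, x_real, z, lr=0.01):
        """Update only D; the generator's a and b are read but never changed."""
        x_fake = self.generate(z)
        p_real, p_fake = self.discriminate(x_real), self.discriminate(x_fake)
        # d/ds of -log sigmoid(s) is (p - 1); d/ds of -log(1 - sigmoid(s)) is p
        grad_w = np.mean((p_real - 1.0) * x_real) + np.mean(p_fake * x_fake)
        grad_c = np.mean(p_real - 1.0) + np.mean(p_fake)
        self.d["w"] -= lr * grad_w
        self.d["c"] -= lr * grad_c

    def g_step(self, z, lr=0.01):
        """Update only G against the frozen D, using the loss -log D(G(z))."""
        p_fake = self.discriminate(self.generate(z))
        w = self.d["w"]
        grad_a = np.mean((p_fake - 1.0) * w * z)
        grad_b = np.mean((p_fake - 1.0) * w)
        self.g["a"] -= lr * grad_a
        self.g["b"] -= lr * grad_b

gan = ToyGAN()
for _ in range(200):
    x_real = gan.rng.normal(4.0, 1.0, 64)  # hypothetical "real" data
    gan.d_step(x_real, gan.rng.normal(size=64))
    gan.g_step(gan.rng.normal(size=64))
print(gan.g, gan.d)
```

The Keras example below gets the same separation by setting `self.discriminator.trainable = False` before compiling the combined model.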

Because a GAN is formulated as a game between two networks, it is important (and difficult!) to keep them in balance. If either the generator or the discriminator becomes too good, the GAN can be hard to train.
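Here is one concrete way that a discriminator which is too good can stall learning. Suppose D confidently rejects fakes, so that D(G(z)) = sigmoid(s) is near 0 for the fake's logit s. Then the gradient of the minimax generator term log(1 - D(G(z))) with respect to s is -sigmoid(s), which vanishes. The alternative "non-saturating" generator loss, -log D(G(z)), keeps a gradient of sigmoid(s) - 1, which stays close to -1. Training the combined Keras model below against labels of 1 amounts to this second loss. The following NumPy check compares the two derivatives; the function names are hypothetical.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def saturating_grad(s):
    """d/ds of log(1 - sigmoid(s)): gradient of the minimax generator term."""
    return -sigmoid(s)

def non_saturating_grad(s):
    """d/ds of -log(sigmoid(s)): gradient of the non-saturating generator loss."""
    return sigmoid(s) - 1.0

# Increasingly confident discriminators push the logit of a fake further below zero
for s in (0.0, -5.0, -10.0):
    print(f"D(G(z)) = {sigmoid(s):.5f}   "
          f"saturating grad = {saturating_grad(s):+.5f}   "
          f"non-saturating grad = {non_saturating_grad(s):+.5f}")
```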

GANs take a long time to train. On a single GPU a GAN might take hours, and on a single CPU it might take several days.

GAN code example

Enough talk. The following is an example GAN implemented in Keras by Erik Linder-Norén:

import os

import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Flatten, Input, LeakyReLU, Reshape
from keras.models import Model, Sequential
from keras.optimizers import Adam

# Written for Keras 2.x semantics: a model's `trainable` flag is captured when it is
# compiled, so the discriminator still trains on its own while staying frozen inside
# the combined model.


class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Each compiled model gets its own optimizer instance; sharing one instance
        # between two models can fail in newer Keras versions.

        # Build and compile the discriminator (accuracy is reported during training)
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=['accuracy'])

        # Build the generator; it is trained only through the combined model below
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

    def build_generator(self):

        noise_shape = (100,)
        model = Sequential()

        # Hidden layers with LeakyReLU nonlinearities (layer sizes are a design choice)
        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(0.2))
        # One tanh output per pixel, matching the [-1, 1] rescaled training data
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)
        model = Sequential()

        # Flatten the 28x28x1 image so the dense layers see a 784-element vector
        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        # Single sigmoid output: probability that the input image is real
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = batch_size // 2

        # Each "epoch" here is a single training iteration on one batch
        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator (real images labeled 1, generated images labeled 0)
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.ones((batch_size, 1))

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # Make sure the output folder exists, then free the figure's memory
        os.makedirs("gan/images", exist_ok=True)
        fig.savefig("gan/images/mnist_%d.png" % epoch)
        plt.close(fig)

if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, save_interval=200)
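The two rescaling lines in the example exist because the generator ends in tanh, whose outputs lie in [-1, 1]. The training pixels are mapped into that range, and generated samples are mapped back to [0, 1] for plotting. The following NumPy check covers both mappings; the helper names are hypothetical.

```python
import numpy as np

def to_tanh_range(pixels):
    """uint8 pixels in [0, 255] -> float32 values in [-1, 1], as in train()."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_display_range(imgs):
    """Generator output in [-1, 1] -> [0, 1] for imshow, as in save_imgs()."""
    return 0.5 * imgs + 0.5

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
print(scaled, to_display_range(scaled))
```

Applying both mappings in sequence simply divides the original pixel values by 255.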

Improving GANs

GANs were only invented in 2014, so they are still very new! They are a very promising family of generative models because, unlike other methods, they produce very clean and sharp images and learn weights that contain valuable information about the underlying data. However, as mentioned above, it can be difficult to keep the discriminator and generator networks in balance. A lot of ongoing work aims to make GAN training more stable.

In addition to generating pretty pictures, GANs have been used to develop a semi-supervised learning approach in which the discriminator produces an additional output indicating the label of the input. This approach has achieved state-of-the-art results on datasets with very few labeled examples. On MNIST, for example, it reaches 99.1% accuracy with a fully connected neural network and only 10 labeled examples per class. That result is very close to the best known fully supervised results, which use all 60,000 labeled examples. This is very promising, because labeled examples can be quite expensive to obtain in practice.
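One common way to give the discriminator that extra output is described by Salimans et al. (2016) in "Improved Techniques for Training GANs". In that formulation, the discriminator classifies each input into K real classes plus one extra "generated" class. It is an assumption here that the MNIST result above comes from exactly this setup. The following minimal NumPy sketch shows how such an output splits into a digit classifier and a real/fake score. The logits are random placeholders, not a trained model, and `split_heads` is a hypothetical name.

```python
import numpy as np

NUM_CLASSES = 10  # MNIST digits; output index 10 is the extra "generated" class

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def split_heads(logits):
    """Turn K+1 discriminator logits into digit probabilities and p(real)."""
    probs = softmax(logits)
    p_real = 1.0 - probs[:, NUM_CLASSES]  # everything except the "generated" class
    # p(digit | x, x is real): renormalize the first K probabilities
    digit_probs = probs[:, :NUM_CLASSES] / p_real[:, None]
    return digit_probs, p_real

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, NUM_CLASSES + 1))  # hypothetical outputs for 3 images
digit_probs, p_real = split_heads(logits)
print(digit_probs.argmax(axis=1), p_real)
```

The few labeled images train the digit part of this output. Unlabeled and generated images train the real/fake part, which is why so few labels can go a long way.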

Conclusion

GANs are still so new, and I'm excited to see where they go! Aren't you?

This article was reposted from Towards Data Science (see the original address).