
Reproducing results on STL10. #52


Description

@CS-GangXu

Hi there, I'm trying to reproduce the experimental results (IS and FID) on the STL10 dataset.

As described in the paper, the image resolution used for the STL10 experiments is 48x48.
However, this repo does not provide a 48x48 model definition (for either D or G),
so I created my own D and G at 48x48 resolution as below:

  • D
import chainer
from chainer import functions as F
from source.links.sn_embed_id import SNEmbedID
from source.links.sn_linear import SNLinear
from dis_models.resblocks import Block, OptimizedBlock

class SNResNetProjectionDiscriminator(chainer.Chain):
    def __init__(self, ch=64, n_classes=0, activation=F.relu):
        super(SNResNetProjectionDiscriminator, self).__init__()
        self.activation = activation
        with self.init_scope():
            self.block1 = OptimizedBlock(3, ch)
            self.block2 = Block(ch, 2*ch, activation=activation, downsample=True)
            self.block3 = Block(2*ch, 4*ch, activation=activation, downsample=True)
            self.block4 = Block(4*ch, 8*ch, activation=activation, downsample=True)
            self.block5 = Block(8*ch, 16*ch, activation=activation, downsample=False)
            self.l5 = SNLinear(16*ch, 1, initialW=chainer.initializers.GlorotUniform(), nobias=True)
            if n_classes > 0:
                # The embedding dim must match the pooled feature dim (16 * ch),
                # otherwise the projection term w_y * h below cannot be computed.
                self.l_y = SNEmbedID(n_classes, 16 * ch, initialW=chainer.initializers.GlorotUniform())

    def __call__(self, x, y=None):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.block5(h)
        h = self.activation(h)
        # Global sum pooling over the spatial dimensions
        h = F.sum(h, axis=(2, 3))
        output = self.l5(h)
        if y is not None:
            w_y = self.l_y(y)
            output += F.sum(w_y * h, axis=1, keepdims=True)
        return output
  • G
import chainer
import chainer.links as L
from chainer import functions as F
from gen_models.resblocks import Block
from source.miscs.random_samples import sample_categorical, sample_continuous

class ResNetGenerator(chainer.Chain):
    def __init__(self, ch=512, dim_z=128, bottom_width=6, activation=F.relu, n_classes=0, distribution="normal"):
        super(ResNetGenerator, self).__init__()
        self.bottom_width = bottom_width
        self.activation = activation
        self.distribution = distribution
        self.dim_z = dim_z
        self.n_classes = n_classes
        with self.init_scope():
            self.l1 = L.Linear(dim_z, (bottom_width ** 2) * ch, initialW=chainer.initializers.GlorotUniform())
            self.block2 = Block(ch, ch // 2, activation=activation, upsample=True, n_classes=n_classes)
            self.block3 = Block(ch // 2, ch // 4, activation=activation, upsample=True, n_classes=n_classes)
            self.block4 = Block(ch // 4, ch // 8, activation=activation, upsample=True, n_classes=n_classes)
            self.b5 = L.BatchNormalization(ch // 8)
            self.c5 = L.Convolution2D(ch // 8, 3, ksize=3, stride=1, pad=1, initialW=chainer.initializers.GlorotUniform())

    def sample_z(self, batchsize=64):
        return sample_continuous(self.dim_z, batchsize, distribution=self.distribution, xp=self.xp)

    def sample_y(self, batchsize=64):
        return sample_categorical(self.n_classes, batchsize, distribution="uniform", xp=self.xp)

    def __call__(self, batchsize=64, z=None, y=None):
        if z is None:
            z = self.sample_z(batchsize)
        if y is None:
            y = self.sample_y(batchsize) if self.n_classes > 0 else None
        if (y is not None) and z.shape[0] != y.shape[0]:
            raise ValueError('z.shape[0] != y.shape[0]')
        h = z
        h = self.l1(h)
        h = F.reshape(h, (h.shape[0], -1, self.bottom_width, self.bottom_width))
        h = self.block2(h, y)
        h = self.block3(h, y)
        h = self.block4(h, y)
        h = self.b5(h)
        h = self.activation(h)
        h = F.tanh(self.c5(h))
        return h
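
To sanity-check both definitions: with bottom_width=6 and three upsampling blocks the generator ends at 6 x 2^3 = 48, while the discriminator downsamples 48 -> 24 -> 12 -> 6 -> 3 before global pooling. A quick forward-pass check (assuming the two classes above import as written):

import numpy as np

# Generator: bottom_width=6 with three upsampling blocks gives 6 -> 12 -> 24 -> 48.
gen = ResNetGenerator(ch=512, dim_z=128, bottom_width=6, n_classes=0)
imgs = gen(batchsize=4)
assert imgs.shape == (4, 3, 48, 48)  # tanh output in [-1, 1]

# Discriminator: 48 -> 24 -> 12 -> 6 -> 3, then global sum pooling to a 16*ch vector.
dis = SNResNetProjectionDiscriminator(ch=64, n_classes=0)
out = dis(imgs.data)
assert out.shape == (4, 1)  # one unnormalized logit per image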

All other hyper-parameters strictly follow the original paper. The YAML configuration is:

batchsize: 128
iteration: 250000
iteration_decay_start: 0
seed: 0
display_interval: 100
progressbar_interval: 100
snapshot_interval: 10000
evaluation_interval: 1000

models:
  generator:
    fn: gen_models/resnet_48.py
    name: ResNetGenerator
    args:
      dim_z: 128
      bottom_width: 6
      ch: 512
      n_classes: 0

  discriminator:
    fn: dis_models/snresnet_48.py
    name: SNResNetProjectionDiscriminator
    args:
      ch: 64
      n_classes: 0

dataset:
  dataset_fn: datasets/stl10.py
  dataset_name: STL10Dataset  # assuming the dataset class defined in datasets/stl10.py is named STL10Dataset
  args:
    test: False

adam:
  alpha: 0.0002
  beta1: 0.0
  beta2: 0.9

updater:
  fn: updater.py
  name: Updater
  args:
    n_dis: 5
    n_gen_samples: 16
    conditional: False
    loss_type: hinge
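
For reference, loss_type: hinge is the standard GAN hinge objective. A minimal sketch of what I expect the updater to compute with this setting (function names here are mine, just for illustration):

from chainer import functions as F

# Hinge losses: the discriminator pushes real outputs above +1 and fake
# outputs below -1; the generator maximizes the discriminator output on fakes.
def loss_hinge_dis(dis_fake, dis_real):
    return F.mean(F.relu(1.0 - dis_real)) + F.mean(F.relu(1.0 + dis_fake))

def loss_hinge_gen(dis_fake):
    return -F.mean(dis_fake)

With n_dis: 5, the discriminator takes five of these steps per generator step.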

I trained the model on a 2-GPU workstation, with each GPU handling 128 images per iteration,
so the effective batch size is 128x2 = 256,
while in your paper the effective batch size is 64x4 = 256.

With this configuration, the best IS I got is 7.8, far from the 9.10 reported in your paper.

Could someone give me some tips on how to reproduce the 9.10 IS from the paper?
