13 changes: 4 additions & 9 deletions examples/vae/README.md
@@ -5,21 +5,16 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo

```bash
pip install -r requirements.txt
python main.py --workdir=/tmp/mnist --config=configs/default.py
python main.py --workdir=/tmp/mnist
```

## Overriding Hyperparameter configurations
## Configuring hyperparameters

This VAE example allows specifying a hyperparameter configuration by the means of
setting `--config` flag. Configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
`config_flags` allows overriding configuration fields. This can be done as
follows:
The VAE example uses simple command line arguments for configuration. You can override the default values as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.01 --config.num_epochs=10
--workdir=/tmp/mnist
```


14 changes: 14 additions & 0 deletions examples/vae/config.py
@@ -0,0 +1,14 @@
"""Simple configuration using dataclasses instead of ml_collections."""
from dataclasses import dataclass

@dataclass
class TrainingConfig:
"""Training configuration parameters."""
learning_rate: float = 0.001
latents: int = 20
batch_size: int = 128
num_epochs: int = 30

def get_default_config() -> TrainingConfig:
"""Get the default configuration."""
return TrainingConfig()
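
A minimal usage sketch of the dataclass config added above (illustrative, not part of the PR): because `TrainingConfig` is a plain dataclass, one-off overrides can be expressed with the standard-library `dataclasses.replace`.

```python
# Illustrative usage of the TrainingConfig dataclass from examples/vae/config.py.
from dataclasses import replace

from config import get_default_config  # the module added in this PR

config = get_default_config()              # TrainingConfig(learning_rate=0.001, latents=20, ...)
quick_run = replace(config, num_epochs=1)  # returns a copy with one field overridden

print(quick_run.learning_rate, quick_run.num_epochs)  # 0.001 1
```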
21 changes: 11 additions & 10 deletions examples/vae/configs/default.py
@@ -14,15 +14,16 @@

"""Default Hyperparameter configuration."""

import ml_collections
from dataclasses import dataclass

@dataclass
class TrainingConfig:
"""Training configuration parameters."""
learning_rate: float = 0.001
latents: int = 20
batch_size: int = 128
num_epochs: int = 30

def get_config():
Collaborator:
Let's keep this file and create training config using dataclass like in your examples/vae/config.py.
Finally, examples/vae/config.py can be removed. You can follow the same approach as here: https://github.com/google/flax/blob/main/examples/gemma/configs/default.py

"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.learning_rate = 0.001
config.latents = 20
config.batch_size = 128
config.num_epochs = 30
return config
def get_default_config() -> TrainingConfig:
"""Get the default configuration."""
return TrainingConfig()
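
If the reviewer's suggestion is followed (keep a dataclass-based `configs/default.py` and drop `config.py`), command-line overrides could still be retained. The sketch below is one possible wiring; it assumes `ml_collections.config_flags.DEFINE_config_dataclass` is available in the installed ml_collections version and supports `--config.<field>=<value>` overrides, so verify the exact API before relying on it.

```python
# Hypothetical main.py wiring (assumption: config_flags.DEFINE_config_dataclass exists in the
# installed ml_collections and supports --config.<field>=<value> overrides for dataclasses).
from absl import app, flags
from ml_collections import config_flags

from configs.default import get_default_config

FLAGS = flags.FLAGS
flags.DEFINE_string('workdir', None, 'Directory to store logs and checkpoints.')
config_flags.DEFINE_config_dataclass(
    'config', get_default_config(), 'Training hyperparameters.')


def main(argv):
  del argv
  config = FLAGS.config  # a TrainingConfig, possibly overridden on the command line
  print(config)          # stand-in for train.train_and_evaluate(config)


if __name__ == '__main__':
  app.run(main)
```

Invocation would then look like `python main.py --workdir=/tmp/mnist --config.learning_rate=0.01`, matching the flag usage this PR removes from the README.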
32 changes: 11 additions & 21 deletions examples/vae/main.py
@@ -21,45 +21,35 @@
from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import time
import train

from configs.default import get_default_config

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True,
)
flags.mark_flags_as_required(['config', 'workdir'])

flags.DEFINE_string('workdir', None, 'Directory to store logs and checkpoints.')

def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

# Parse arguments and get config
config = get_default_config()

# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')

logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())

# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)

train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
# Simple process logging
logging.info('Starting training process %d/%d', jax.process_index(), jax.process_count())

start = time.perf_counter()
train.train_and_evaluate(config)
logging.info('Total training time: %.2f seconds', time.perf_counter() - start)

if __name__ == '__main__':
app.run(main)
Collaborator:
@sanepunk why do you remove abseil app and the usage of config file?
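
A complementary option that keeps the README's "command line arguments" claim concrete without ml_collections is to map a few plain absl flags onto the dataclass. The sketch below is illustrative only; the flag names and the `build_config` helper are assumptions, not part of this PR.

```python
# Hedged sketch: overriding individual TrainingConfig fields with plain absl flags
# (flag names and build_config are illustrative, not part of this PR).
import dataclasses

from absl import flags

from configs.default import get_default_config

flags.DEFINE_float('learning_rate', None, 'Optional override for TrainingConfig.learning_rate.')
flags.DEFINE_integer('num_epochs', None, 'Optional override for TrainingConfig.num_epochs.')


def build_config(FLAGS):
  overrides = {name: value
               for name, value in {'learning_rate': FLAGS.learning_rate,
                                   'num_epochs': FLAGS.num_epochs}.items()
               if value is not None}
  return dataclasses.replace(get_default_config(), **overrides)
```

`main()` would then call `build_config(FLAGS)` instead of `get_default_config()`, and the README invocation could become `python main.py --workdir=/tmp/mnist --learning_rate=0.01`.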

47 changes: 25 additions & 22 deletions examples/vae/models.py
@@ -14,44 +14,47 @@

"""VAE model definitions."""

from flax import linen as nn
from flax import nnx
from jax import random
import jax.numpy as jnp


class Encoder(nn.Module):
class Encoder(nnx.Module):
"""VAE Encoder."""

latents: int
def __init__(self, input_features: int, latents: int, *, rngs: nnx.Rngs):
self.fc1 = nnx.Linear(input_features, 500, rngs=rngs)
self.fc2_mean = nnx.Linear(500, latents, rngs=rngs)
self.fc2_logvar = nnx.Linear(500, latents, rngs=rngs)

@nn.compact
def __call__(self, x):
x = nn.Dense(500, name='fc1')(x)
x = nn.relu(x)
mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
x = self.fc1(x)
x = nnx.relu(x)
mean_x = self.fc2_mean(x)
logvar_x = self.fc2_logvar(x)
return mean_x, logvar_x


class Decoder(nn.Module):
class Decoder(nnx.Module):
"""VAE Decoder."""

@nn.compact
def __init__(self, latents: int, output_features: int, *, rngs: nnx.Rngs):
self.fc1 = nnx.Linear(latents, 500, rngs=rngs)
self.fc2 = nnx.Linear(500, output_features, rngs=rngs)

def __call__(self, z):
z = nn.Dense(500, name='fc1')(z)
z = nn.relu(z)
z = nn.Dense(784, name='fc2')(z)
z = self.fc1(z)
z = nnx.relu(z)
z = self.fc2(z)
return z


class VAE(nn.Module):
class VAE(nnx.Module):
"""Full VAE model."""

latents: int = 20

def setup(self):
self.encoder = Encoder(self.latents)
self.decoder = Decoder()
def __init__(self, input_features: int, latents: int, rngs: nnx.Rngs):
self.encoder = Encoder(input_features=input_features, latents=latents, rngs=rngs)
self.decoder = Decoder(latents=latents, output_features=input_features, rngs=rngs)

def __call__(self, x, z_rng):
mean, logvar = self.encoder(x)
@@ -60,7 +63,7 @@ def __call__(self, x, z_rng):
return recon_x, mean, logvar

def generate(self, z):
return nn.sigmoid(self.decoder(z))
return nnx.sigmoid(self.decoder(z))


def reparameterize(rng, mean, logvar):
@@ -69,5 +72,5 @@ def reparameterize(rng, mean, logvar):
return mean + eps * std


def model(latents):
return VAE(latents=latents)
def model(input_features: int, latents: int, rngs: nnx.Rngs):
return VAE(input_features=input_features, latents=latents, rngs=rngs)
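
A short usage sketch of the nnx modules above (illustrative; the dummy batch and latent draws are assumptions, shapes follow the flattened 28x28 MNIST setup used elsewhere in this example).

```python
# Illustrative forward pass with the nnx VAE defined in models.py above.
import jax.numpy as jnp
from flax import nnx
from jax import random

import models  # the module shown in this diff

rngs = nnx.Rngs(0)  # parameters are created eagerly at construction time
vae = models.model(input_features=784, latents=20, rngs=rngs)

x = jnp.ones((8, 784))                                # dummy batch of flattened images
recon_logits, mean, logvar = vae(x, random.key(1))    # z_rng drives the reparameterization
samples = vae.generate(random.normal(random.key(2), (16, 20)))  # decode latent draws

print(recon_logits.shape, mean.shape, samples.shape)
```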
6 changes: 3 additions & 3 deletions examples/vae/requirements.txt
@@ -1,7 +1,7 @@
absl-py==1.4.0
flax==0.6.9
numpy==1.23.5
flax~=0.12
numpy>=1.26.4
optax==0.1.5
Pillow==10.2.0
tensorflow==2.12.0
tensorflow-cpu~=2.18.0
tensorflow-datasets==4.9.2
78 changes: 31 additions & 47 deletions examples/vae/train.py
@@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation logic."""
from typing import Any

from absl import logging
from flax import linen as nn
from flax import nnx
import input_pipeline
import models
import utils as vae_utils
from flax.training import train_state
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
from configs.default import TrainingConfig
import optax
import tensorflow as tf
import tensorflow_datasets as tfds


@@ -36,7 +33,7 @@ def kl_divergence(mean, logvar):

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
logits = nn.log_sigmoid(logits)
logits = nnx.log_sigmoid(logits)
return -jnp.sum(
labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))
)
@@ -47,44 +44,37 @@ def compute_metrics(recon_x, x, mean, logvar):
kld_loss = kl_divergence(mean, logvar).mean()
return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss}


def train_step(state, batch, z_rng, latents):
"""Train step."""
def loss_fn(params):
recon_x, mean, logvar = models.model(latents).apply(
{'params': params}, batch, z_rng
)

@nnx.jit
Collaborator:
Let's use donate args to donate model and optimizer to reduce GPU memory usage.

Contributor Author:
I tried adding donate_argnums to nnx.jit in the train_step, but was getting NaN loss and kl divergence.
what to do?

What to do about this?

def train_step(model: nnx.Module, optimizer: nnx.Optimizer, batch, z_rng):
"""Single training step for the VAE model."""
def loss_fn(model):
recon_x, mean, logvar = model(batch, z_rng)
bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
kld_loss = kl_divergence(mean, logvar).mean()
loss = bce_loss + kld_loss
return loss

grads = jax.grad(loss_fn)(state.params)
return state.apply_gradients(grads=grads)


def eval_f(params, images, z, z_rng, latents):
"""Evaluation function."""
def eval_model(vae):
recon_images, mean, logvar = vae(images, z_rng)
comparison = jnp.concatenate([
images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1),
])
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
return loss

generate_images = vae.generate(z)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)
return metrics, comparison, generate_images
@nnx.jit
def eval_f(model: nnx.Module, images, z, z_rng):
"""Evaluation function for the VAE model."""
recon_images, mean, logvar = model(images, z_rng)
comparison = jnp.concatenate([
images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1),
])
generate_images = model.generate(z)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)
return metrics, comparison, generate_images

return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
def train_and_evaluate(config: TrainingConfig):
"""Train and evaulate pipeline."""
tf.io.gfile.makedirs(workdir)

rng = random.key(0)
rng, key = random.split(rng)

@@ -96,14 +86,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
test_ds = input_pipeline.build_test_set(ds_builder)

logging.info('Initializing model.')
init_data = jnp.ones((config.batch_size, 784), jnp.float32)
params = models.model(config.latents).init(key, init_data, rng)['params']

state = train_state.TrainState.create(
apply_fn=models.model(config.latents).apply,
params=params,
tx=optax.adam(config.learning_rate),
)
rngs = nnx.Rngs(0)
model = models.model(784, config.latents, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(config.learning_rate), wrt=nnx.Param)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, config.latents))
@@ -116,16 +101,15 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
state = train_step(state, batch, key, config.latents)
loss_val = train_step(model, optimizer, batch, key)

metrics, comparison, sample = eval_f(
state.params, test_ds, z, eval_rng, config.latents
)
model, test_ds, z, eval_rng)
vae_utils.save_image(
comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8
comparison, f'results/reconstruction_{epoch}.png', nrow=8
)
vae_utils.save_image(
sample, f'{workdir}/sample_{epoch}.png', nrow=8
sample, f'results/sample_{epoch}.png', nrow=8
)

print(