Run a diffusion transformer on a free Google Colab account

Category: Image Generation
Topic: diffusion, diffusion transformer

Published by Nicole on Nov 03, 2024 • 5 min read.

This post shows a memory-efficient approach to generating high-quality images with PixArt-Σ, a state-of-the-art diffusion transformer for ultra-high-resolution image synthesis. As deep learning models grow larger, memory management becomes critical, especially when you work with limited GPU resources such as the GPU on a free Colab account.

In this example, you reduce the memory footprint with quantization from two libraries: bitsandbytes, configured through the BitsAndBytesConfig class in Transformers, and optimum.quanto. The goal is to save memory while keeping output quality close to that of the unquantized model. The walkthrough covers setting up a quantized text encoder, generating prompt embeddings, and managing GPU memory efficiently during image generation.

You will also monitor GPU memory usage throughout the process with PyTorch's memory utilities, and reduce memory consumption further by freezing the quantized weights and deleting components as soon as they are no longer needed. By the end of this post, you will have a practical understanding of how to run large models on limited hardware and generate high-quality images with reduced memory overhead.

The provided code is inspired by the examples in Hugging Face's quanto and Diffusers libraries.

I omit the installs and some small helper functions here, but you can find them all in the Google Colab notebook.
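The memory printouts below call a helper named to_giga_bytes that the post does not show. Here is a minimal sketch, assuming it simply divides a byte count by 1024³ (so "GB" here really means gibibytes), which is consistent with the fractional values printed later. This is a hypothetical reconstruction; the notebook's own version may differ.

```python
def to_giga_bytes(num_bytes: int) -> float:
    """Convert a byte count (e.g. from torch.cuda.memory_allocated()) to GiB.

    Hypothetical reconstruction of the helper omitted from the post.
    """
    return num_bytes / 1024**3


# 6746 MiB expressed in GiB
print(to_giga_bytes(6746 * 1024**2))  # 6.587890625
```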

 

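Start by enabling GPU memory history recording, so you can track how much memory your model uses at each stage. This is essential for debugging and optimizing memory management. Note that _record_memory_history is an underscore-prefixed, semi-private PyTorch API, so its arguments and behavior can change between releases.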

torch.cuda.memory._record_memory_history()

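Next, define the quantization configuration for the text encoder. It loads the weights in 4-bit NF4 format, applies double quantization (the per-block quantization constants are themselves quantized), and performs computations in bfloat16. This substantially cuts the encoder's memory footprint at a small cost in precision, which makes it well suited to large models on limited GPUs.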

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
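A back-of-the-envelope estimate shows what these settings buy. The sketch assumes the figures reported in the QLoRA paper, which introduced NF4: weights quantized in blocks of 64 with one 32-bit absmax constant per block, and, with double quantization, those constants stored in 8 bits plus one extra 32-bit constant per 256 blocks. The helper names are illustrative, the parameter count is a placeholder rather than the exact size of PixArt-Σ's T5 encoder, and real models keep some layers unquantized, so treat the output as a rough guide.

```python
def bits_per_param(weight_bits: int = 4, block_size: int = 64,
                   double_quant: bool = False, second_block: int = 256) -> float:
    """Approximate storage cost per weight, including quantization constants."""
    if double_quant:
        # 8-bit constant per block, plus one fp32 constant per `second_block` blocks
        overhead = 8 / block_size + 32 / (block_size * second_block)
    else:
        # one fp32 absmax constant per block
        overhead = 32 / block_size
    return weight_bits + overhead


def weight_gib(num_params: float, bits: float) -> float:
    """Weight memory in GiB for `num_params` parameters at `bits` per parameter."""
    return num_params * bits / 8 / 1024**3


n = 4.7e9  # placeholder parameter count, for illustration only
for label, bits in [("fp16", 16.0),
                    ("nf4", bits_per_param()),
                    ("nf4 + double quant", bits_per_param(double_quant=True))]:
    print(f"{label:>20}: {bits:.3f} bits/param -> {weight_gib(n, bits):.2f} GiB")
```

Double quantization shrinks the constant overhead from 0.5 to about 0.127 bits per parameter, a small but free saving on top of the 4x reduction from 16-bit to 4-bit weights.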

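Then initialize the T5EncoderModel (the text encoder used by PixArt-Σ) with these quantization settings. device_map="balanced" spreads the model's layers evenly across the available GPUs; on a single Colab GPU, everything simply lands on that one device.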

text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="text_encoder",
    quantization_config=quant_config,
    device_map="balanced",
)

Next, load the PixArtSigmaPipeline with your quantized text encoder. Passing transformer=None skips loading the large diffusion transformer at this stage, so it does not occupy GPU memory while you encode the prompt.

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced"
)

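Inside torch.no_grad(), encode the prompt into embeddings. Disabling gradient tracking means PyTorch does not keep intermediate results around for backpropagation, which reduces memory consumption and speeds up processing. encode_prompt returns the prompt embeddings and attention mask together with their negative (unconditional) counterparts, which the pipeline uses for classifier-free guidance.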


with torch.no_grad():
    prompt = "Cute animated tabby with big eyes"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

Next, print the peak memory allocated by tensors so far and the memory currently reserved by PyTorch's caching allocator, to understand the memory demands of your setup at this point:

print(
    f"Max memory allocated: {to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

print(
    f"Memory reserved: {to_giga_bytes(torch.cuda.memory_reserved())} GB"
)

This gives:

Max memory allocated: 6.249999046325684 GB

Memory reserved: 6.587890625 GB

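To conserve GPU memory, delete the text_encoder and pipe objects once the prompt embeddings are computed, run the garbage collector, and empty PyTorch's CUDA cache. Resetting the peak statistics afterwards makes the following readings show what is still in use after the cleanup, rather than the peak reached while encoding the prompt.

del text_encoder
del pipe

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()  # start peak tracking fresh for the next stage

print(
    f"Max memory allocated: {to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

print(
    f"Memory reserved: {to_giga_bytes(torch.cuda.memory_reserved())} GB"
)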
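Both del statements are needed because the pipeline keeps its own reference to the text encoder, and an object's memory can only be reclaimed once nothing refers to it. A minimal, GPU-free sketch of that rule using plain Python objects and weakref (FakeEncoder and FakePipeline are hypothetical stand-ins, not Diffusers types):

```python
import gc
import weakref


class FakeEncoder:
    """Stand-in for a large model whose memory we want back."""


class FakePipeline:
    """Stand-in for a pipeline that holds a reference to its encoder."""

    def __init__(self, encoder):
        self.text_encoder = encoder


text_encoder = FakeEncoder()
pipe = FakePipeline(text_encoder)
probe = weakref.ref(text_encoder)  # observes the object without keeping it alive

del text_encoder
gc.collect()
alive_after_first_del = probe() is not None  # True: pipe.text_encoder still refers to it

del pipe
gc.collect()
alive_after_second_del = probe() is not None  # False: no references remain

print(alive_after_first_del, alive_after_second_del)  # True False
```

The same logic explains the ordering in the post: torch.cuda.empty_cache() can only hand cached blocks back to the driver once no tensor is using them, so it runs after the deletes.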

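Next, load the pipeline again, this time with the transformer in float16 but without the text encoder, since the prompt embeddings are already computed. Then quantize the transformer's weights to int8 with optimum.quanto, leaving the final output projection (proj_out) unquantized. quantize() swaps the layers for quanto's quantized modules, which still hold float weights and quantize them on the fly; freeze() then replaces the float weights with their int8 values, which is where the memory saving actually comes from.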

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

quantize(pipe.transformer, weights=qint8, exclude="proj_out")
freeze(pipe.transformer)
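Conceptually, weight-only int8 quantization stores each weight as an 8-bit integer plus a floating-point scale. The sketch below uses a simple symmetric absmax scheme with one scale per row (output channel) to illustrate the idea and its error bound; it is an assumption for illustration, not a description of how optimum.quanto lays out its tensors internally.

```python
def quantize_row(row):
    """Symmetric absmax int8 quantization of one row of weights."""
    scale = max(abs(w) for w in row) / 127 or 1.0  # avoid a zero scale for all-zero rows
    q = [max(-127, min(127, round(w / scale))) for w in row]
    return q, scale


def dequantize_row(q, scale):
    """Recover approximate float weights from int8 values and the row scale."""
    return [v * scale for v in q]


weights = [
    [0.12, -0.50, 0.33, 0.01],
    [1.75, -0.20, 0.00, -1.10],
]

for row in weights:
    q, scale = quantize_row(row)
    restored = dequantize_row(q, scale)
    max_err = max(abs(a - b) for a, b in zip(row, restored))
    # rounding to the nearest integer keeps the error within half a quantization step
    assert max_err <= scale / 2 + 1e-12
    print(q, f"scale={scale:.5f}", f"max_err={max_err:.5f}")

# storage: 1 byte per int8 weight vs 2 bytes per float16 weight, plus one scale per row
```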

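Using the prompt embeddings, run the denoising loop. Setting output_type="latent" makes the pipeline return the latent representations instead of decoded images, so VAE decoding can happen as a separate step after the transformer has been removed.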

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

print(
    f"Max memory allocated: {to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

print(
    f"Memory reserved: {to_giga_bytes(torch.cuda.memory_reserved())} GB"
)
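Holding on to the latents while deleting the transformer is cheap, because latents are far smaller than the decoded image. A quick estimate, assuming a 1024×1024 output and a VAE that downsamples by a factor of 8 into 4 latent channels (typical for Stable Diffusion-family VAEs; check pipe.vae_scale_factor and pipe.vae.config.latent_channels on your own setup). The helper names are illustrative.

```python
def latent_shape(height: int, width: int, channels: int = 4, factor: int = 8):
    """Shape (C, H, W) of the latent tensor for one image."""
    return channels, height // factor, width // factor


def num_elements(shape):
    """Total number of values in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


lat = latent_shape(1024, 1024)  # assumed 4 channels, downsampling factor 8
img = (3, 1024, 1024)           # decoded RGB image

print(lat)                                      # (4, 128, 128)
print(num_elements(lat) * 2 / 1024**2)          # fp16 latent size in MiB: 0.125
print(num_elements(img) // num_elements(lat))   # 48 times fewer values than the image
```

So the memory needed in the final step comes mostly from the VAE itself and its full-resolution activations, not from the latents.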

Once the latents are computed, delete the transformer, clear the cache, and reset the peak statistics again. This frees GPU memory for the final decoding step.

del pipe.transformer

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()  # track the decoding step's peak separately

print(
    f"Max memory allocated: {to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

print(
    f"Memory reserved: {to_giga_bytes(torch.cuda.memory_reserved())} GB"
)

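Now decode the latents into an actual image and save it as tabby.png (rename the file to match what you generate). The latents are divided by the VAE's scaling_factor first to undo the scaling applied to the VAE's latent space.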

with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")

image[0].save("tabby.png")
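pipe.image_processor.postprocess turns the VAE output, whose values lie roughly in [-1, 1], into 8-bit pixels. A pure-Python sketch of that mapping for a single value, assuming the usual Diffusers convention of computing x / 2 + 0.5, clamping to [0, 1], then scaling to 0 to 255 and rounding; to_pixel is a hypothetical helper, not part of the library.

```python
def to_pixel(x: float) -> int:
    """Map one VAE output value (nominally in [-1, 1]) to an 8-bit pixel value."""
    v = min(max(x / 2 + 0.5, 0.0), 1.0)  # denormalize to [0, 1] and clamp outliers
    return round(v * 255)


print([to_pixel(x) for x in (-1.0, 0.0, 1.0, 1.3)])  # [0, 128, 255, 255]
```

The clamp matters in practice: the decoder can overshoot the nominal range slightly, and without it those values would wrap around when cast to 8-bit integers.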

Here is the image my run generated: 

Finally, save a memory snapshot to a file and print a summary of GPU memory usage. The snapshot contains the allocation history recorded since _record_memory_history() was called, which is valuable for future debugging and optimization.

torch.cuda.memory._dump_snapshot("PixArtSigma_quant.pickle")

print(
    torch.cuda.memory_summary()
)

Download the resulting pickle file and upload it to PyTorch's memory visualizer (pytorch.org/memory_viz) to view your memory usage history. This creates a plot similar to the one here:

The plot shows the memory history for quantized PixArt-Σ. Each stripe is a separate tensor allocation, and the thinner the stripe, the smaller the tensor.