Text2Face-LoRa


This is a LoRA-finetuned version of the Stable Diffusion 2.1 model, optimized specifically for generating face images. The model was trained on FFHQ and EasyPortrait, using synthetic text captions for both datasets. Details on the dataset format and preparation will be available soon.
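LoRA (Low-Rank Adaptation) keeps a pretrained weight matrix W frozen and learns two small matrices, B (d×r) and A (r×k), whose product is added to W. The pure-Python sketch below illustrates the merged weight W + scale · (alpha / r) · B·A on toy 2×2 matrices. The rank, alpha, and scale values are illustrative assumptions, not the hyperparameters used to train Text2Face-LoRa; in diffusers, `alpha` roughly corresponds to the `network_alphas` returned with the LoRA state dict, and `scale` to the runtime LoRA scale.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_merge(w, b, a, alpha, scale=1.0):
    """Return W + scale * (alpha / r) * (B @ A), where r is the LoRA rank.

    w: frozen base weight, d x k
    b: LoRA "up" matrix, d x r
    a: LoRA "down" matrix, r x k
    """
    r = len(a)  # rank = number of rows of A
    delta = matmul(b, a)
    factor = scale * alpha / r
    return [[w[i][j] + factor * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# Toy example (illustrative values): a 2x2 base weight with a rank-1 update.
W = [[1.0, 0.0],
     [0.0, 1.0]]
B = [[1.0],
     [2.0]]          # 2 x 1
A = [[0.5, 0.5]]     # 1 x 2

print(lora_merge(W, B, A, alpha=1.0, scale=1.0))  # [[1.5, 0.5], [1.0, 2.0]]
print(lora_merge(W, B, A, alpha=1.0, scale=0.0))  # scale 0 recovers W: [[1.0, 0.0], [0.0, 1.0]]
```

Because only B and A are trained, the LoRA file stays small, and setting the scale to 0 recovers the base model's behavior.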

Checkpoints

You can download the pretrained LoRA weights for the diffusion model (UNet) and the text encoder with:

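from huggingface_hub import hf_hub_download

# The file keeps its in-repo path under local_dir, so local_dir="." places it at
# checkpoints/lora30k/pytorch_lora_weights.safetensors, the default checkpoint
# directory used by the inference script below.
path = hf_hub_download(repo_id="michaeltrs/text2face",
                       filename="checkpoints/lora30k/pytorch_lora_weights.safetensors",
                       local_dir=".")
print(path)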
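`hf_hub_download` recreates the file's in-repo subfolders under `local_dir`, so the `local_dir` value decides whether the weights end up in `checkpoints/lora30k`, the directory the inference code loads from by default. The sketch below, assuming a recent `huggingface_hub` release with that behavior, computes the expected location and downloads only when the file is missing; `expected_local_path` and `ensure_weights` are hypothetical helper names.

```python
from pathlib import Path

REPO_ID = "michaeltrs/text2face"
FILENAME = "checkpoints/lora30k/pytorch_lora_weights.safetensors"


def expected_local_path(local_dir, filename=FILENAME):
    """Where hf_hub_download is assumed to place `filename` under `local_dir`.

    Assumption: the in-repo subfolders of `filename` are recreated under local_dir.
    """
    return Path(local_dir) / filename


def ensure_weights(local_dir="."):
    """Hypothetical helper: download the LoRA weights only if they are not already on disk."""
    path = expected_local_path(local_dir)
    if not path.exists():
        # Imported here so the path logic above works without huggingface_hub installed.
        from huggingface_hub import hf_hub_download
        path = Path(hf_hub_download(repo_id=REPO_ID, filename=FILENAME, local_dir=local_dir))
    return path


# local_dir="." yields checkpoints/lora30k/..., matching the default checkpoint directory;
# local_dir="checkpoints" nests the file one level deeper (checkpoints/checkpoints/lora30k/...).
print(expected_local_path("."))
print(expected_local_path("checkpoints"))
```

If the weights are stored elsewhere, the `checkpoint` argument of the inference script's `Model` class can point at the directory that actually contains the `.safetensors` file.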

Inference

Generate images with the generate.py script, which loads the SD 2.1 foundation model from Hugging Face and applies the LoRA weights. Generation is controlled by a text prompt and an optional negative prompt.

import os

import torch
from diffusers import StableDiffusionPipeline


class Model:
    def __init__(self, checkpoint="checkpoints/lora30k", weight_name="pytorch_lora_weights.safetensors", device="cuda"):
        self.checkpoint = checkpoint
        # Load the LoRA weights (UNet and text encoder) from the local checkpoint directory.
        state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(
            checkpoint,
            weight_name=weight_name
        )
        # Load the SD 2.1 base model in half precision, then attach the LoRA weights
        # to both the UNet and the text encoder under the adapter name "test_lora".
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16).to(device)
        self.pipe.load_lora_into_unet(state_dict, network_alphas, self.pipe.unet, adapter_name='test_lora')
        self.pipe.load_lora_into_text_encoder(state_dict, network_alphas, self.pipe.text_encoder, adapter_name='test_lora')
        self.pipe.set_adapters(["test_lora"], adapter_weights=[1.0])

    def generate(self, prompt, negprompt='', steps=50, savedir=None, seed=1):
        lora_scale = 1.0
        image = self.pipe(prompt,
                          negative_prompt=negprompt,
                          num_inference_steps=steps,
                          cross_attention_kwargs={"scale": lora_scale},
                          generator=torch.manual_seed(seed)).images[0]  # fixed seed for reproducibility
        # Save as <prompt_with_underscores>.png in savedir, or in the checkpoint
        # directory when no savedir is given.
        outdir = self.checkpoint if savedir is None else savedir
        os.makedirs(outdir, exist_ok=True)
        filename = '_'.join(prompt.replace('.', ' ').split()) + '.png'
        image.save(os.path.join(outdir, filename))
        return image


if __name__ == "__main__":

    model = Model()

    prompt = 'A happy 55 year old male with blond hair and a goatee smiles with visible teeth.'
    negprompt = ''

    image = model.generate(prompt, negprompt=negprompt, steps=50, seed=42)

Limitations

This model, Text2Face-LoRa, is finetuned from Stable Diffusion 2.1 and, as such, inherits the limitations and biases of the base model. These biases may appear as skewed representations across ethnicities and genders, owing to the data originally used to train Stable Diffusion 2.1.

Specific Limitations Include:

  • Ethnic and Gender Biases: The model may generate images that do not equally represent the diversity of human features in different ethnic and gender groups, potentially reinforcing or exacerbating existing stereotypes.

  • Selection Bias in Finetuning Datasets: The datasets used for finetuning this model were selected with specific criteria in mind, which may not encompass a wide enough variety of data points to correct for the inherited biases of the base model.

  • Caption Generation Bias: The synthetic annotations used to finetune this model were generated by automated face analysis models, which may themselves be biased. This could lead to inaccurate interpretation and representation of facial features, particularly for demographics underrepresented in the training data.

Ethical Considerations:

Users are encouraged to consider these limitations when deploying the model in real-world applications, especially those involving diverse human subjects. It is advisable to perform additional validation and to seek ways to mitigate these biases in practical use cases.
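One way to begin such validation is to generate faces from prompts that differ in a single demographic attribute at a time, reuse the same seeds for every group, and review the outputs side by side. The sketch below builds such a balanced prompt grid with `itertools.product`. The attribute values, the prompt template, and the `audit_grid` helper are illustrative assumptions, not an evaluation protocol from the model authors; the groups to compare should be chosen for the application at hand.

```python
from itertools import product

# Illustrative attribute values (assumptions); choose groups suited to the use case.
AGES = [25, 45, 65]
GENDERS = ["male", "female"]
ETHNICITIES = ["Black", "East Asian", "South Asian", "White", "Hispanic", "Middle Eastern"]
SEEDS = [0, 1, 2]

# Hypothetical prompt template in the style of the example prompt; the default
# ages above all take the article "A".
TEMPLATE = "A {age} year old {ethnicity} {gender}."


def audit_grid(ages=AGES, genders=GENDERS, ethnicities=ETHNICITIES, seeds=SEEDS,
               template=TEMPLATE):
    """Return (prompt, seed, attributes) for every combination of attribute values and seed.

    Every combination is rendered with the same list of seeds, so images for different
    groups start from the same initial noise and differences come from the prompt.
    """
    grid = []
    for age, gender, ethnicity, seed in product(ages, genders, ethnicities, seeds):
        prompt = template.format(age=age, gender=gender, ethnicity=ethnicity)
        grid.append((prompt, seed, {"age": age, "gender": gender, "ethnicity": ethnicity}))
    return grid


grid = audit_grid()
print(len(grid))   # 3 ages x 2 genders x 6 ethnicities x 3 seeds = 108
print(grid[0][0])  # A 25 year old Black male.
```

Each entry can then be passed to `model.generate(prompt, seed=seed)` from the inference script, and the saved images compared across groups for systematic differences in quality or attributes.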
