---
license: bsd-3-clause-clear
---

# WAFFLE: Multi-Modal Model for Automated Front-End Development
We develop WAFFLE, a fine-tuning approach for training multi-modal LLMs (MLLMs) to generate HTML code from webpage screenshots or UI designs. WAFFLE uses a structure-aware attention mechanism to improve MLLMs' understanding of HTML's structure and a contrastive fine-tuning approach to align MLLMs' understanding of UI images and HTML code. Models fine-tuned with WAFFLE show up to 9.00 pp (percentage points) higher HTML match, 0.0982 higher CW-SSIM, 32.99 higher CLIP, and 27.12 pp higher LLEM on our new benchmark WebSight-Test and an existing benchmark Design2Code.

## Updates:
* 10/24/2024: Our preprint is available at: [arXiv](https://arxiv.org/abs/2410.18362), [huggingface](https://huggingface.co/papers/2410.18362)
* 10/24/2024: Our code (actively maintained) is available at: [code](https://github.com/lt-asset/Waffle)
* 10/24/2024: Our Waffle_VLM_WebSight model (7B), fine-tuned with DoRA, is released at: [lt-asset/Waffle_VLM_WebSight](https://huggingface.co/lt-asset/Waffle_VLM_WebSight)

## Dependency
- Python 3.10.14
- pytorch 2.3.0
- transformers 4.41.1
- peft 0.11.1
- accelerate 0.30.1
- deepspeed 0.14.1
- datasets 2.19.1
- beautifulsoup4 4.12.3
- selenium
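
To confirm that an environment matches the pinned versions, a small check along the following lines may help. This is a minimal sketch rather than part of the WAFFLE code base: the `PINNED` table and the `check_versions` helper are illustrative names, and it assumes the packages were installed under their usual PyPI names (`torch` for PyTorch). Selenium and Python itself are left out because the list gives no selenium version and Python is not a pip package.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions copied from the dependency list; keys are assumed PyPI package names.
PINNED = {
    "peft": "0.11.1",
    "transformers": "4.41.1",
    "torch": "2.3.0",
    "deepspeed": "0.14.1",
    "datasets": "2.19.1",
    "beautifulsoup4": "4.12.3",
    "accelerate": "0.30.1",
}


def check_versions(pinned, get_version=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "2.3.0+cu121" when comparing.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_versions(PINNED).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every listed package is installed at the listed version. Other versions may also work, but they are not covered by the list.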

## Quick Start
* Input UI design

Pick a webpage screenshot or UI design to use as input:

![test-495.png](examples/test-495.png)

* Run Waffle_VLM_WebSight
```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from utils import TreeBuilder


def convert_to_rgb(image):
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def inference_vlm_websight(image_path, html_path):
    
    def custom_transform(x):
        x = convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = processor.image_processor.rescale(x, scale=1 / 255)
        x = processor.image_processor.normalize(
            x,
            mean=processor.image_processor.image_mean,
            std=processor.image_processor.image_std
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        x = torch.tensor(x)
        return x

    model_dir = "lt-asset/Waffle_VLM_WebSight"
    processor = AutoProcessor.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
    
    assert model.config.web_attention_range == 2, "Waffle_VLM_WebSight is trained with hierarchical attention applied to 2 / 8 heads"
    # use 2/8 = 1/4 attention heads for hierarchical attention (as described in paper)
    model.eval()

    image_seq_len = model.config.perceiver_config.resampler_n_latents
    BOS_TOKEN = processor.tokenizer.bos_token
    BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

    image = Image.open(image_path)
    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor([image], transform=custom_transform).to(dtype=torch.bfloat16)
    inputs_for_generation = {k: v.cuda() for k, v in inputs.items()}
    inputs_for_generation["web_attention_mask"] = None
    inputs_for_generation["html_tree"] = TreeBuilder(processor.tokenizer)
    inputs_for_generation["html_tree"].web_attention_mask = inputs_for_generation["web_attention_mask"]

    generated_ids = model.generate(
        **inputs_for_generation, bad_words_ids=BAD_WORDS_IDS, max_length=2048, 
        num_return_sequences=1, do_sample=False
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
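    # Write the generated HTML as UTF-8 so non-ASCII page content is preserved on any platform
    with open(html_path, 'w', encoding='utf-8') as wp:
        wp.write(generated_text)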


if __name__ == '__main__':
    inference_vlm_websight('examples/test-495.png', 'examples/example-495.html')
```

* Waffle_VLM_WebSight generated HTML code

[example-495.html](examples/example-495.html)

* Rendered Waffle_VLM_WebSight output

Render or preview the generated HTML to check that it matches the input design:

![example-495.png](examples/example-495.png)
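
One way to produce such a rendering is to load the HTML file in a headless browser and save a screenshot. The sketch below uses selenium, which is already in the dependency list, and is not taken from the WAFFLE repository: `html_file_url` and `render_html` are hypothetical helper names, the window size and output filename are arbitrary choices, and it assumes a recent Chrome and a matching driver are available on the machine.

```python
from pathlib import Path


def html_file_url(html_path):
    """Turn a local HTML path into a file:// URL that a browser can open."""
    return Path(html_path).resolve().as_uri()


def render_html(html_path, png_path, width=1280, height=1024):
    """Open the HTML file in headless Chrome and save a screenshot (assumed setup)."""
    # Imported lazily so html_file_url stays usable without selenium installed.
    from selenium import webdriver

    options = webdriver.ChromeOptions()
    options.add_argument("--headless=new")  # assumes a Chrome build that supports this flag
    driver = webdriver.Chrome(options=options)
    try:
        driver.set_window_size(width, height)
        driver.get(html_file_url(html_path))
        driver.save_screenshot(png_path)
    finally:
        # Always close the browser, even if loading or saving fails.
        driver.quit()


if __name__ == "__main__":
    # Output name is illustrative; it avoids overwriting the bundled example image.
    render_html("examples/example-495.html", "examples/example-495-rendered.png")
```

The screenshot covers only the browser window, so a long page may need a larger `height` to be captured in full.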

## Citation
```
@misc{liang2024wafflemultimodalmodelautomated,
      title={WAFFLE: Multi-Modal Model for Automated Front-End Development}, 
      author={Shanchao Liang and Nan Jiang and Shangshu Qian and Lin Tan},
      year={2024},
      eprint={2410.18362},
      archivePrefix={arXiv},
      primaryClass={cs.SE},
      url={https://arxiv.org/abs/2410.18362}, 
}
```

## License
The model is built on top of [VLM_WebSight_finetuned](https://huggingface.co/HuggingFaceM4/VLM_WebSight_finetuned). As such, users should comply with the license of that model.

The DoRA weights we trained are integrated with the original model's weights to produce the final model. We release the final model's weights under the Apache-2.0 license.