---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---

# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This is an evolution of https://huggingface.co/aehrc/cxrmate, developed for the Radiology Report Generation task of BioNLP @ ACL 2024.

For this, we proposed Entropy-Augmented Self-critical sequence Training (EAST):

- EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation.
- This helps maintain a higher entropy in the token distribution.
- A higher entropy prevents overfitting to common phrases and ensures a broader exploration of the vocabulary during training.
- This was essential for handling the diversity of the radiology reports in the RRG24 datasets.
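
As a rough illustration of the idea, the per-sample objective can be sketched as the usual SCST policy-gradient loss minus an entropy bonus. This is a minimal pure-Python sketch under assumptions, not the authors' implementation: the reward of the greedy-decoded report is used as the SCST baseline (standard SCST practice), entropy is averaged over decoding steps, and `entropy_coef` and the function names are hypothetical.

```python
import math
from typing import Sequence


def token_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of one next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def east_loss(
    sampled_log_probs: Sequence[float],
    sampled_reward: float,
    baseline_reward: float,
    token_distributions: Sequence[Sequence[float]],
    entropy_coef: float = 0.01,  # Hypothetical weight, not taken from the paper.
) -> float:
    """Sketch of an entropy-augmented SCST loss for one sampled report.

    sampled_log_probs: log-probability of each sampled token.
    sampled_reward / baseline_reward: reward of the sampled report and of the
        greedy-decoded (baseline) report, e.g. a RadGraph-based score.
    token_distributions: the model's next-token distribution at each step.
    """
    # SCST: reinforce the sample in proportion to how much it beats the baseline.
    advantage = sampled_reward - baseline_reward
    scst_loss = -advantage * sum(sampled_log_probs)

    # Entropy regularisation: subtracting the mean entropy favours flatter distributions.
    mean_entropy = sum(token_entropy(d) for d in token_distributions) / len(token_distributions)
    return scst_loss - entropy_coef * mean_entropy
```

For example, `east_loss([math.log(0.5)], 1.0, 0.5, [[0.5, 0.5]], entropy_coef=0.1)` evaluates to 0.5 · ln 2 − 0.1 · ln 2 = 0.4 · ln 2 ≈ 0.277. Minimising this loss still increases the likelihood of samples that beat the baseline, while the entropy term pushes against collapsing onto a few high-probability phrases; the coefficient trades reward against exploration.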

EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:

- Token type embeddings to differentiate between findings section tokens, impression section tokens, and image embeddings.
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
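
The input layout implied by these features can be sketched as follows. This is a hypothetical illustration rather than the model's actual code: the token-type IDs (`IMAGE`, `FINDINGS`, `IMPRESSION`), the helper names, the placement of the image embeddings before the report tokens, and the choice that image embeddings do not attend to report tokens are all assumptions.

```python
from typing import List, Optional, Tuple

# Hypothetical token-type IDs (assumed for illustration, not the model's actual values).
IMAGE, FINDINGS, IMPRESSION = 0, 1, 2


def build_report(
    findings: Optional[List[str]], impression: Optional[List[str]]
) -> Tuple[List[str], List[int]]:
    """Substitute [NF]/[NI] for a missing section and tag every token with its section type."""
    findings = findings or ["[NF]"]
    impression = impression or ["[NI]"]
    tokens = findings + impression
    token_types = [FINDINGS] * len(findings) + [IMPRESSION] * len(impression)
    return tokens, token_types


def build_attention_mask(num_image: int, num_text: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j.

    Assumes the image embeddings come first, followed by the report tokens.
    """
    n = num_image + num_text
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if j < num_image:
                # Non-causal: every position sees all image embeddings.
                mask[i][j] = True
            elif i >= num_image and j <= i:
                # Causal: a report token sees only itself and earlier report tokens.
                mask[i][j] = True
    return mask


# Usage: a report with a missing findings section, prefixed by two image embeddings.
tokens, report_types = build_report(None, ["no", "acute", "findings"])
token_type_ids = [IMAGE] * 2 + report_types
mask = build_attention_mask(num_image=2, num_text=len(tokens))
```

Here `tokens` is `['[NF]', 'no', 'acute', 'findings']` and `token_type_ids` is `[0, 0, 1, 2, 2, 2]`. The resulting mask is a "prefix" pattern: a full block over the image embeddings and a lower-triangular block over the report tokens, so generation can condition on every image while remaining autoregressive over the text.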

## How to use:

```python
import torch
import transformers
from PIL import Image
from torchvision.transforms import v2

tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)

# Preprocessing that matches the image encoder's configuration:
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

# Load the chest X-ray(s) of one study ('path/to/cxr.jpg' is a placeholder path):
images = [Image.open('path/to/cxr.jpg')]
images = torch.stack([transforms(image) for image in images])  # [num_images, 3, H, W]

# Add a batch dimension. The [batch, num_images, 3, H, W] layout is an assumption;
# see the notebook example for the exact input format.
images = images.unsqueeze(0)

# Generate a report; [NF] and [NI] are banned so no missing-section placeholder is output:
output_ids = model.generate(
    pixel_values=images,
    max_length=512,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
    num_beams=4,
    use_cache=True,
)

# Split the generated sequence into its findings and impression sections:
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```

## Notebook example:

https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb

## Paper:

## Citation:

[More Information Needed]