---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---

# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This is an evolution of https://huggingface.co/aehrc/cxrmate developed for the Radiology Report Generation task of BioNLP @ ACL 2024. For this, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
- EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation.
- This helps maintain a higher entropy in the token distribution, preventing overfitting to common phrases and encouraging broader exploration of the vocabulary during training.
- This was essential for handling the diversity of the radiology reports in the RRG24 datasets.

EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
- Token type embeddings to differentiate between findings section tokens, impression section tokens, and image embeddings.
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
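The entropy-augmented objective described above can be sketched as a loss function. This is a minimal PyTorch sketch, assuming the standard SCST formulation (sampled-report reward minus a greedy-decoding baseline reward) plus an entropy bonus averaged over non-padding tokens. The function name `east_loss`, the `entropy_coef` weighting, and the averaging scheme are illustrative assumptions, not the authors' implementation; the rewards would come from a scorer such as RadGraph.

```python
import torch
import torch.nn.functional as F


def east_loss(
    sample_logits: torch.Tensor,  # [B, T, V] logits used to sample the reports
    sample_ids: torch.Tensor,     # [B, T] sampled token ids
    mask: torch.Tensor,           # [B, T] 1 for report tokens, 0 for padding
    sample_reward: torch.Tensor,  # [B] reward of each sampled report (e.g. RadGraph)
    greedy_reward: torch.Tensor,  # [B] reward of the greedy (baseline) report
    entropy_coef: float = 0.01,   # hypothetical weight of the entropy bonus
) -> torch.Tensor:
    """Hypothetical sketch of SCST with an entropy regulariser (not the authors' code)."""
    log_probs = F.log_softmax(sample_logits, dim=-1)
    mask = mask.to(log_probs.dtype)

    # Log-probability of each sampled token.
    token_log_probs = log_probs.gather(-1, sample_ids.unsqueeze(-1)).squeeze(-1)

    # Self-critical advantage: the greedy report's reward acts as the baseline.
    advantage = (sample_reward - greedy_reward).detach()

    # Policy-gradient (REINFORCE) term, normalised by the number of report tokens.
    seq_log_prob = (token_log_probs * mask).sum(-1)
    scst = -(advantage * seq_log_prob).sum() / mask.sum()

    # Entropy of the full token distribution at every non-padding step.
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    mean_entropy = (entropy * mask).sum() / mask.sum()

    # Subtracting the entropy bonus favours flatter (higher-entropy) distributions.
    return scst - entropy_coef * mean_entropy
```

With equal sampled and greedy rewards the policy-gradient term vanishes, so a peaked distribution incurs a higher loss than a uniform one purely through the entropy bonus.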
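The mixed attention scheme listed above (non-causal over image embeddings, causal over report tokens) can be pictured as one boolean mask over the concatenated sequence. This sketch assumes the image embeddings come first, followed by the report tokens; `build_attention_mask` is a hypothetical helper for illustration, not part of the model's code.

```python
import torch


def build_attention_mask(num_image_tokens: int, num_text_tokens: int) -> torch.Tensor:
    """Hypothetical helper: [L, L] boolean mask where True means 'may attend'."""
    total = num_image_tokens + num_text_tokens
    # Start from a causal (lower-triangular) mask over the whole sequence.
    mask = torch.tril(torch.ones(total, total, dtype=torch.bool))
    # Image embeddings attend to each other bidirectionally (non-causal),
    # but still cannot see any report tokens.
    mask[:num_image_tokens, :num_image_tokens] = True
    return mask
```

Each report token therefore sees every image embedding plus the report tokens before it, while the image embeddings see only one another.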
## Example:

```python
import datasets
import torch
import transformers
from torch.utils.data import DataLoader
from torchvision.transforms import v2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mbatch_size = 4  # Mini-batch size; adjust to fit your hardware.

tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True).to(device)
model.eval()

# Image preprocessing matching the encoder's expected input.
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    # Each study can have several images: transform each one and stack them per study.
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    return batch

def collate_fn(examples):
    # Studies have different numbers of images, so pad them to [B, max_images, 3, H, W].
    images = [example['images'] for example in examples]
    return {'images': torch.nn.utils.rnn.pad_sequence(images, batch_first=True, padding_value=0.0)}

dataset = dataset.with_transform(transform_batch)
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True, collate_fn=collate_fn)
batch = next(iter(dataloader))

output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```

## Generate findings only:

```python
output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    eos_token_id=tokenizer.sep_token_id,  # Stop at the end of the findings section.
)
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```

## Generate impression only:

```python
# Prompt with an empty findings section ([BOS] [NF] [SEP]) so that only the impression is generated.
prompt = [tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]
output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    input_ids=torch.tensor([prompt] * batch['images'].shape[0], device=device, dtype=torch.long),
)
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```

## Notebook example:

https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb

## Paper:

## Citation:

[More Information Needed]