How did you apply bi-directional attention for embedding only?

#9
by Mayfull - opened

Hello. Thanks for your great work and the helpful code in your GitHub repo! It helps me a lot.
I am wondering how you apply bi-directional attention only for embedding while keeping causal attention for generation in the same model.
What I understood so far from your code (https://github.com/ContextualAI/gritlm/blob/main/gritlm/training/model.py) is as follows.

I see that you add an 'is_causal' key to the "kwargs" dict and pass it to the model call when self.attn[:2] == 'bb' (the default is 'bbcc').
But I can't understand how this removes the causal mask when most models (e.g. MistralForCausalLM) have no argument named 'is_causal' in their forward function.
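
Whether an extra 'is_causal' key can have any effect depends only on the signature of the function it is passed to; this is ordinary Python keyword-argument behavior. The minimal sketch below uses three hypothetical stand-in functions (`forward_fixed`, `forward_swallowing`, `forward_custom`), not the real `MistralForCausalLM.forward`: a fixed signature rejects the key, a `**kwargs` signature accepts it but may ignore it, and only a forward that declares and reads the parameter can change the mask.

```python
def forward_fixed(input_ids, attention_mask=None):
    # Hypothetical forward with a fixed signature: unknown keywords raise TypeError.
    return "causal"


def forward_swallowing(input_ids, attention_mask=None, **kwargs):
    # Hypothetical forward that collects extra keywords but never reads them,
    # so is_causal is accepted and silently ignored.
    return "causal"


def forward_custom(input_ids, attention_mask=None, is_causal=True):
    # Hypothetical custom forward that actually consumes the flag.
    return "causal" if is_causal else "bidirectional"


kwargs = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "is_causal": False}

try:
    forward_fixed(**kwargs)
except TypeError as err:
    print("fixed signature:", err)

print("**kwargs signature:", forward_swallowing(**kwargs))  # flag ignored
print("custom signature:", forward_custom(**kwargs))        # flag honored
```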

Would you mind giving me some advice or information on how it works?
Thank you so much.

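        kwargs = {'input_ids': features.get('input_ids'), 'attention_mask': attention_mask}

        # 'cb' mode: also pass the instruction lengths to the model call
        if self.attn[:2] == 'cb':
            kwargs['instruction_lens'] = instruction_lens
        # 'bb' mode: ask the model to drop the causal mask for the embedding pass
        elif self.attn[:2] == 'bb':
            kwargs['is_causal'] = False
        # Call the submodule named by embedding_attr if set, otherwise the whole model; keep the first output
        out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]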
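
The mode prefixes in the snippet above can be pictured as different attention masks over the same sequence. Below is a minimal pure-Python sketch; `build_mask` is a hypothetical helper, not code from the gritlm repo, and reading `'cb'` as causal over the first `instruction_len` tokens and bidirectional among the remaining tokens is an assumption inferred only from the `instruction_lens` argument.

```python
from typing import List


def build_mask(attention_mask: List[int], mode: str, instruction_len: int = 0) -> List[List[bool]]:
    """Return allowed[i][j]: may query position i attend to key position j?

    'cc': causal, 'bb': bidirectional,
    'cb': ASSUMED meaning, causal over the first instruction_len tokens and
          bidirectional among the tokens after them.
    """
    if mode not in ("cc", "bb", "cb"):
        raise ValueError(f"unknown mode: {mode!r}")
    n = len(attention_mask)
    allowed = []
    for i in range(n):
        row = []
        for j in range(n):
            if attention_mask[j] == 0:
                ok = False  # never attend to padding keys
            elif mode == "bb":
                ok = True  # every position sees the whole sequence
            elif mode == "cc":
                ok = j <= i  # only the current and earlier positions
            else:  # "cb"
                ok = j <= i or (i >= instruction_len and j >= instruction_len)
            row.append(ok)
        allowed.append(row)
    return allowed


def show(mask: List[List[bool]]) -> None:
    # Print one row per query position: '1' = may attend, '.' = masked out.
    for row in mask:
        print("".join("1" if ok else "." for ok in row))


if __name__ == "__main__":
    pad = [1, 1, 1, 1, 0]  # the last position is padding
    for mode in ("cc", "bb", "cb"):
        print(mode)
        show(build_mask(pad, mode, instruction_len=2))
```

Only the mask differs between modes; the token sequence and the model weights are the same.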

The repo has a custom modeling file that has the is_causal kwarg: https://huggingface.co/GritLM/GritLM-7B/blob/main/modeling_gritlm7b.py
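
Once a forward accepts `is_causal`, the flag only has to decide which keys each position may attend to; the weights stay the same, so one model can embed with bidirectional attention and generate with causal attention. A minimal pure-Python sketch under that assumption: `toy_attention` is a hypothetical single-head function with no learned projections (queries, keys and values are the inputs themselves), not code from `modeling_gritlm7b.py`.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Numerically stable softmax over a list of scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def toy_attention(x: List[Vector], is_causal: bool = True) -> List[Vector]:
    """Single-head self-attention with q = k = v = x (hypothetical toy).

    is_causal=True:  position i sees positions 0..i (generation).
    is_causal=False: every position sees the whole sequence (embedding).
    """
    n, d = len(x), len(x[0])
    out = []
    for i in range(n):
        keys = range(i + 1) if is_causal else range(n)
        # Scaled dot-product scores against the visible keys only.
        scores = [sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d) for j in keys]
        weights = softmax(scores)
        out.append([sum(w * x[j][k] for w, j in zip(weights, keys)) for k in range(d)])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy token vectors
causal = toy_attention(x, is_causal=True)
bidirectional = toy_attention(x, is_causal=False)
print(causal[0])         # [1.0, 0.0]: the first token can only see itself
print(bidirectional[0])  # also mixes in the later tokens
```

The last position gets the same output in both modes, because causal attention already lets it see the whole sequence; only the earlier positions change.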

Thank you so much! Now I understand!
