from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from transformers import AutoTokenizer, AutoModelForCausalLM

import transformers
import os
import torch
import gradio as gr

import subprocess

#command = 'pip install git+https://github.com/huggingface/transformers'
#subprocess.run(command, shell=True)

# check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")
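# optional: report which GPU was detected (illustrative addition, not part of the original script)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))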

# define the model id
# model_id = "tiiuae/falcon-40b-instruct"
model_id = "tiiuae/falcon-7b-instruct"

# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load the model
## params:
## cache_dir: directory where the downloaded pretrained weights are cached (instead of the default HF cache)
## torch_dtype: bfloat16 roughly halves memory use compared to float32
## trust_remote_code: Falcon ships custom modelling code on the Hub, which must be allowed to run
## device_map="auto": lets the model be placed automatically across your GPU(s)/CPU
## offload_folder: disk location for any weights that do not fit in memory
cache_dir = "./workspace/"
torch_dtype = torch.bfloat16
trust_remote_code = True
device_map = "auto"
offload_folder = "offload"

model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=torch_dtype,
                                             trust_remote_code=trust_remote_code, device_map=device_map,
                                             offload_folder=offload_folder)
# put the pytorch model in inference (eval) mode
model.eval()
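# optional: report the approximate model memory footprint (illustrative addition, not part of the original script)
print(f"model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")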

# build the hf transformers pipeline
task = "text-generation"
max_length = 400
do_sample = True
top_k = 10
num_return_sequences = 1
eos_token_id = tokenizer.eos_token_id

pipeline = transformers.pipeline(task, model=model, tokenizer=tokenizer,
                                 device_map=device_map, max_length=max_length,
                                 do_sample=do_sample, top_k=top_k,
                                 num_return_sequences=num_return_sequences,
                                 eos_token_id=eos_token_id)
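# optional smoke test of the raw transformers pipeline before wiring it into langchain
# (the prompt below is illustrative, not part of the original script; uncomment to try it)
# sample = pipeline("Explain what a falcon is in one sentence.")
# print(sample[0]["generated_text"])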
            
# set up the prompt template
template = PromptTemplate(input_variables=['input'], template='{input}')

# pass the hf pipeline to the langchain wrapper class
llm = HuggingFacePipeline(pipeline=pipeline)

# build the stacked llm chain, i.e. prompt formatting + llm
chain = LLMChain(llm=llm, prompt=template)
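# optional end-to-end check of the chain before exposing it in the UI
# (example question is illustrative, not part of the original script; uncomment to try it)
# print(chain.run("What are falcons commonly used for?"))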


# create the generate function
def generate(prompt):
  # the prompt gets passed to the llm chain, which returns the response
  return chain.run(prompt)

title = "Falcon 40-b-Instruct 🦅"
description = "Web app application using the open-source `Falcon-40b-Instruct` LLM"
  
# build the gradio interface
gr.Interface(fn=generate,
             inputs=["text"],
             outputs=["text"],
             title=title,
             description=description).launch()