MadsGalsgaard committed on
Commit 7daf88a · verified · 1 Parent(s): 9ae8177

Update app.py

Files changed (1):
  app.py +182 -96
app.py CHANGED
@@ -441,80 +441,159 @@
 
  ###########new clientkey
 
  # import gradio as gr
- # from huggingface_hub import InferenceClient
 
- # # Hugging Face Inference Client setup
- # client = InferenceClient(
- #     model="meta-llama/Meta-Llama-3.1-8B-Instruct"  # model repo id, not a token
- # )
 
- # # Function to interact with the Hugging Face model
- # def chat_with_model(message, history):
- #     # Prepare the conversation history for the model
- #     conversation = [{"role": "system", "content": "You are a helpful assistant."}]
-
- #     for past_message, past_response in history:
- #         conversation.append({"role": "user", "content": past_message})
- #         conversation.append({"role": "assistant", "content": past_response})
-
- #     # Add the new user message to the conversation
- #     conversation.append({"role": "user", "content": message})
-
- #     # Generate a response using the Inference API
- #     responses = client.chat_completion(
- #         messages=conversation,
- #         max_tokens=500,
- #         stream=True
  # )
-
- #     # Capture the streamed response
- #     response_text = ""
- #     for response in responses:
- #         delta_content = response.choices[0].delta.content
- #         response_text += delta_content
 
- #     history.append((message, response_text))
-
- #     return history, history  # Update both the chatbot history and the visible chat
 
- # # Create the Gradio interface
- # with gr.Blocks() as demo:
- #     chatbot = gr.Chatbot(height=600)
- #     msg_input = gr.Textbox(show_label=False, placeholder="Type your message...")
-
- #     with gr.Row():
- #         clear_btn = gr.Button("Clear Chat")
-
- #     # Set up the interaction between user input and the chatbot
- #     msg_input.submit(chat_with_model, [msg_input, chatbot], [chatbot, chatbot])
- #     clear_btn.click(lambda: None, None, chatbot, queue=False)
-
- #     gr.Markdown("## Llama 3.1 Chatbot")
 
- # # Launch the Gradio demo
  # if __name__ == "__main__":
  #     demo.launch()
 
-
- import os
- import time
- import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import gradio as gr
  from threading import Thread
 
  MODEL = "THUDM/LongWriter-llama3.1-8b"
-
  TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
-
  PLACEHOLDER = """
  <center>
  <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
  </center>
  """
-
  CSS = """
  .duplicate-button {
      margin: auto !important;
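
One note on the streamed-completion loop removed in the hunk above: with OpenAI-style streaming clients, choices[0].delta.content can be None on role-only or final chunks, so the bare += risks a TypeError. A defensive sketch, not part of the commit, with responses as in the deleted code:

# Defensive variant of the removed accumulation loop (sketch only).
response_text = ""
for chunk in responses:
    delta = chunk.choices[0].delta.content
    if delta:  # delta.content may be None on some chunks
        response_text += delta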
@@ -527,54 +606,61 @@ h3 {
  }
  """
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
  tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
- model = model.eval()
 
- @spaces.GPU()
  def stream_chat(
      message: str,
      history: list,
      system_prompt: str,
      temperature: float = 0.5,
-     max_new_tokens: int = 32768,
      top_p: float = 1.0,
      top_k: int = 50,
  ):
-     print(f'message: {message}')
-     print(f'history: {history}')
-
-     full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
-     for prompt, answer in history:
-         full_prompt += f"[INST]{prompt}[/INST]{answer}"
-     full_prompt += f"[INST]{message}[/INST]"
-
-     inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
-     context_length = inputs.input_ids.shape[-1]
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         inputs=inputs.input_ids,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         streamer=streamer,
-     )
-
-     thread = Thread(target=model.generate, kwargs=generate_kwargs)
-     thread.start()
-
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         yield buffer
-
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
  with gr.Blocks(css=CSS, theme="soft") as demo:
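
Both the stream_chat removed above and its replacement further down assemble a Llama-2-style prompt by string concatenation. A standalone sketch (illustrative values only, not part of the commit) of the layout it produces:

# Reproduces stream_chat's prompt assembly with sample values.
system_prompt = "You are a helpful assistant."
history = [("Hi", "Hello! How can I help?")]
message = "Write a poem."

full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
for prompt, answer in history:
    full_prompt += f"[INST]{prompt}[/INST]{answer}"
full_prompt += f"[INST]{message}[/INST]"

# full_prompt now reads:
# <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# [INST]Hi[/INST]Hello! How can I help?[INST]Write a poem.[/INST]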
@@ -601,9 +687,9 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
      ),
      gr.Slider(
          minimum=1024,
-         maximum=32768,
          step=1024,
-         value=32768,
          label="Max new tokens",
          render=False,
      ),
@@ -624,12 +710,12 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
          render=False,
      ),
  ],
- examples=[
-     ["Write a 5000-word comprehensive guide on machine learning for beginners."],
-     ["Create a detailed 3000-word business plan for a sustainable energy startup."],
-     ["Compose a 2000-word short story set in a futuristic underwater city."],
-     ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
- ],
  cache_examples=False,
  )
 
  ###########new clientkey
 
+
+ # import os
+ # import time
+ # import spaces
+ # import torch
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  # import gradio as gr
+ # from threading import Thread
 
+ # MODEL = "THUDM/LongWriter-llama3.1-8b"
 
+ # TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
+
+ # PLACEHOLDER = """
+ # <center>
+ # <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
+ # </center>
+ # """
+
+ # CSS = """
+ # .duplicate-button {
+ #     margin: auto !important;
+ #     color: white !important;
+ #     background: black !important;
+ #     border-radius: 100vh !important;
+ # }
+ # h3 {
+ #     text-align: center;
+ # }
+ # """
+
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
+ # model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
+ # model = model.eval()
+
+ # @spaces.GPU()
+ # def stream_chat(
+ #     message: str,
+ #     history: list,
+ #     system_prompt: str,
+ #     temperature: float = 0.5,
+ #     max_new_tokens: int = 32768,
+ #     top_p: float = 1.0,
+ #     top_k: int = 50,
+ # ):
+ #     print(f'message: {message}')
+ #     print(f'history: {history}')
+
+ #     full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+ #     for prompt, answer in history:
+ #         full_prompt += f"[INST]{prompt}[/INST]{answer}"
+ #     full_prompt += f"[INST]{message}[/INST]"
+
+ #     inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
+ #     context_length = inputs.input_ids.shape[-1]
+
+ #     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+ #     generate_kwargs = dict(
+ #         inputs=inputs.input_ids,
+ #         max_new_tokens=max_new_tokens,
+ #         do_sample=True,
+ #         top_p=top_p,
+ #         top_k=top_k,
+ #         temperature=temperature,
+ #         num_beams=1,
+ #         streamer=streamer,
  # )
 
+ #     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+ #     thread.start()
 
+ #     buffer = ""
+ #     for new_text in streamer:
+ #         buffer += new_text
+ #         yield buffer
+
+ # chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
+ # with gr.Blocks(css=CSS, theme="soft") as demo:
+ #     gr.HTML(TITLE)
+ #     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+ #     gr.ChatInterface(
+ #         fn=stream_chat,
+ #         chatbot=chatbot,
+ #         fill_height=True,
+ #         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+ #         additional_inputs=[
+ #             gr.Textbox(
+ #                 value="You are a helpful assistant capable of generating long-form content.",
+ #                 label="System Prompt",
+ #                 render=False,
+ #             ),
+ #             gr.Slider(
+ #                 minimum=0,
+ #                 maximum=1,
+ #                 step=0.1,
+ #                 value=0.5,
+ #                 label="Temperature",
+ #                 render=False,
+ #             ),
+ #             gr.Slider(
+ #                 minimum=1024,
+ #                 maximum=32768,
+ #                 step=1024,
+ #                 value=32768,
+ #                 label="Max new tokens",
+ #                 render=False,
+ #             ),
+ #             gr.Slider(
+ #                 minimum=0.0,
+ #                 maximum=1.0,
+ #                 step=0.1,
+ #                 value=1.0,
+ #                 label="Top p",
+ #                 render=False,
+ #             ),
+ #             gr.Slider(
+ #                 minimum=1,
+ #                 maximum=100,
+ #                 step=1,
+ #                 value=50,
+ #                 label="Top k",
+ #                 render=False,
+ #             ),
+ #         ],
+ #         examples=[
+ #             ["Write a 5000-word comprehensive guide on machine learning for beginners."],
+ #             ["Create a detailed 3000-word business plan for a sustainable energy startup."],
+ #             ["Compose a 2000-word short story set in a futuristic underwater city."],
+ #             ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
+ #         ],
+ #         cache_examples=False,
+ #     )
 
  # if __name__ == "__main__":
  #     demo.launch()
 
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import gradio as gr
  from threading import Thread
 
+ # Model and constants
  MODEL = "THUDM/LongWriter-llama3.1-8b"
  TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
  PLACEHOLDER = """
  <center>
  <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
  </center>
  """
  CSS = """
  .duplicate-button {
      margin: auto !important;
  ...
  }
  """
 
+ # Check the device
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
+ # Load the model and tokenizer
  tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto").eval()
 
  def stream_chat(
      message: str,
      history: list,
      system_prompt: str,
      temperature: float = 0.5,
+     max_new_tokens: int = 4096,  # Lowered max tokens for efficiency
      top_p: float = 1.0,
      top_k: int = 50,
  ):
+     try:
+         full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+         for prompt, answer in history:
+             full_prompt += f"[INST]{prompt}[/INST]{answer}"
+         full_prompt += f"[INST]{message}[/INST]"
+
+         # Tokenize the input
+         inputs = tokenizer(full_prompt, truncation=True, max_length=2048, return_tensors="pt").to(device)
+         context_length = inputs.input_ids.shape[-1]
+
+         # Set up a TextIteratorStreamer for the streaming response
+         streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+         # Generation parameters
+         generate_kwargs = dict(
+             inputs=inputs.input_ids,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             top_p=top_p,
+             top_k=top_k,
+             temperature=temperature,
+             num_beams=1,
+             streamer=streamer,
+         )
+
+         # Generate text in a separate thread to avoid blocking
+         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
+
+         # Stream the response
+         buffer = ""
+         for new_text in streamer:
+             buffer += new_text
+             yield buffer
+
+     except Exception as e:
+         yield f"An error occurred: {str(e)}"
+
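
A hedged note on the new tokenization, not part of the commit: tokenizer(..., truncation=True, max_length=2048) truncates from the right by default, which clips the newest turn, i.e. the trailing [INST]{message}[/INST], once a conversation outgrows 2048 tokens. If the intent is to drop the oldest context instead, the tokenizer's truncation side can be flipped; a minimal sketch under that assumption:

# Sketch only: keep the most recent turns when the prompt exceeds max_length.
tokenizer.truncation_side = "left"  # default is "right"
inputs = tokenizer(full_prompt, truncation=True, max_length=2048, return_tensors="pt").to(device)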
+ # Gradio setup
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
  with gr.Blocks(css=CSS, theme="soft") as demo:
  ...
      ),
      gr.Slider(
          minimum=1024,
+         maximum=4096,  # Reduced to a more manageable value
          step=1024,
+         value=4096,
          label="Max new tokens",
          render=False,
      ),
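
A quick sanity check on the lowered ceiling (rough numbers, roughly 0.75 English words per Llama token): 4096 new tokens × 0.75 ≈ 3,072 words, well short of the "10,000+ words" the PLACEHOLDER copy still promises, so one of the two likely wants a follow-up edit.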
 
          render=False,
      ),
  ],
+ # examples=[
+ #     ["Write a 5000-word comprehensive guide on machine learning for beginners."],
+ #     ["Create a detailed 3000-word business plan for a sustainable energy startup."],
+ #     ["Compose a 2000-word short story set in a futuristic underwater city."],
+ #     ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
+ # ],
  cache_examples=False,
  )
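
The hunks shown here end before any launch call, and a launch may well exist further down the file outside this diff. Assuming the file keeps the usual Gradio tail (it appears, commented out, in the old block above), a minimal launch sketch:

# Typical tail for a Gradio Space; queue() is optional but helps with streamed generators.
if __name__ == "__main__":
    demo.queue().launch()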