import os
import sys
import time
import urllib.request
import json
import random
import requests
from voice import voice_dict
from dotenv import load_dotenv
# Load API credentials from credentials.env into the process environment and
# read each one. os.getenv returns None for a missing key, so a missing
# credential is not detected here -- only when the corresponding API is called.
load_dotenv('credentials.env')
# OpenAI API key, sent as a Bearer token by get_story.
OPENAPI_KEY = os.getenv('OPENAPI_KEY')
# NCP API Gateway key pair (X-NCP-APIGW-API-KEY-ID / X-NCP-APIGW-API-KEY)
# for the CLOVA Voice text-to-speech endpoint used by get_voice.
CLOVA_VOICE_Client_ID = os.getenv('CLOVA_VOICE_Client_ID')
CLOVA_VOICE_Client_Secret = os.getenv('CLOVA_VOICE_Client_Secret')
# NCP API Gateway key pair for the Papago translation endpoint (translate_text).
PAPAGO_Translate_Client_ID = os.getenv('PAPAGO_Translate_Client_ID')
PAPAGO_Translate_Client_Secret = os.getenv('PAPAGO_Translate_Client_Secret')
# Mubert personal access token, sent as "pat" in Mubert API requests.
mubert_pat = os.getenv('mubert_pat')
# NCP API Gateway key pair for the text-summary endpoint (get_summary).
SUMMARY_Client_ID = os.getenv('SUMMARY_Client_ID')
SUMMARY_Client_Secret = os.getenv('SUMMARY_Client_Secret')

import time
import os
import subprocess
from tempfile import NamedTemporaryFile

import torch
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen

# Shared MusicGen model, loaded once at import time and used by
# get_musicgen_music. This is the 'melody' checkpoint (not 'small');
# 'facebook/musicgen-small', '-medium' and '-large' are alternative
# checkpoint names that can be swapped in here.
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Generation settings applied to every generate()/generate_continuation()
# call: sampling with top-k = 250. `duration` is presumably the length in
# seconds of each generated segment -- confirm against the audiocraft docs.
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=30
)


def get_voice(input_text:str, gender:str="female", age_group:str="youth", speed:int=1, pitch:int=1, alpha:int=-1, filename="voice.mp3"):
    """Synthesize ``input_text`` with the CLOVA Voice premium TTS API and save it.

    A speaker is chosen at random from ``voice_dict[gender][age_group]``.

    Args:
        input_text: Text to speak.
        gender: "female" or "male".
        age_group: "child", "teenager", "youth" or "middle_aged".
        speed, pitch, alpha: Sent unchanged as the API's ``speed``, ``pitch``
            and ``alpha`` form fields.
        filename: Path the response body (audio) is written to.

    Returns:
        ``filename``. The file is written only when the API answers 200; on
        any other status the error is printed and the path is still
        returned, so the file may be missing or left over from an earlier
        call.

    Note:
        An unknown ``gender``/``age_group`` presumably raises ``KeyError``
        from the ``voice_dict`` lookup -- depends on voice_dict's type.
    """
    speaker = random.choice(voice_dict[gender][age_group])
    data = {"speaker": speaker, "text": input_text, 'speed': speed, 'pitch': pitch, 'alpha': alpha}
    url = "https://naveropenapi.apigw.ntruss.com/tts-premium/v1/tts"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": CLOVA_VOICE_Client_ID,
        "X-NCP-APIGW-API-KEY": CLOVA_VOICE_Client_Secret,
    }
    response = requests.post(url, headers=headers, data=data)
    if response.status_code == 200:
        print("TTS mp3 저장")
        with open(filename, 'wb') as f:
            f.write(response.content)
    else:
        print("Error Code: " + str(response.status_code))
        # Error bodies are not guaranteed to be JSON (e.g. gateway HTML
        # pages); fall back to the raw text instead of crashing while
        # reporting the error.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        print("Error Message: " + str(detail))
    return filename
    
def translate_text(text:str):
    """Translate Korean ``text`` to English with the Papago NMT API.

    Args:
        text: Korean source text. Empty text (or None) is not sent to the
            API; ``""`` is returned directly.

    Returns:
        The translated English text, or the string ``"Error Code: <status>"``
        when the API answers with an HTTP error status. Both are plain
        strings, so callers that care must check for that prefix.
    """
    # Import explicitly: these submodules were previously only reachable
    # because urllib.request happens to import them.
    from urllib.error import HTTPError
    from urllib.parse import quote

    if not text:
        return ""

    encText = quote(text)
    data = f"source=ko&target=en&text={encText}"
    url = "https://naveropenapi.apigw.ntruss.com/nmt/v1/translation"

    request = urllib.request.Request(url)
    request.add_header("X-NCP-APIGW-API-KEY-ID", PAPAGO_Translate_Client_ID)
    request.add_header("X-NCP-APIGW-API-KEY", PAPAGO_Translate_Client_Secret)

    try:
        # ``with`` closes the HTTP response (and its connection) after reading.
        with urllib.request.urlopen(request, data=data.encode("utf-8")) as response:
            response_body = response.read()
    except HTTPError as e:
        return f"Error Code: {e.code}"
    return json.loads(response_body.decode('utf-8'))['message']['result']['translatedText']


# -

def get_summary(input_text:str, summary_count:int = 5):
    """Summarize Korean ``input_text`` with the NCP text-summary API.

    The text is stripped and cut to its first 2000 characters before being
    sent (language "ko", model "general", tone "0").

    Args:
        input_text: Document to summarize.
        summary_count: Value sent as the API's ``summaryCount`` option.

    Returns:
        The summary with its line breaks replaced by spaces. If the (stripped)
        text is empty, the API rejects it with error code ``E100``, or the
        request fails for any other reason, the stripped/truncated input text
        is returned instead, so the result is always a string.
    """
    # Strip before truncating so surrounding whitespace does not use up the
    # 2000-character budget; ``or ""`` tolerates None.
    input_text = (input_text or "").strip()[:2000].rstrip()
    if not input_text:
        # Nothing to summarize; skip the API call.
        return input_text

    data = {
        "document": {
            "content": input_text
        },
        "option": {
            "language": "ko",
            "model": "general",
            "tone": "0",
            "summaryCount": summary_count
        }
    }
    url = "https://naveropenapi.apigw.ntruss.com/text-summary/v1/summarize"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": SUMMARY_Client_ID,
        "X-NCP-APIGW-API-KEY": SUMMARY_Client_Secret,
        "Content-Type": "application/json"
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    if response.status_code == 200:
        # Collapse the multi-line summary into a single line.
        return ' '.join(response.json()['summary'].split('\n'))

    # Error bodies may not be JSON, or may lack the expected keys.
    try:
        detail = response.json()
    except ValueError:
        detail = response.text
    error = detail.get('error') if isinstance(detail, dict) else None
    error_code = error.get('errorCode') if isinstance(error, dict) else None
    if response.status_code == 400 and error_code == 'E100':
        # E100 is treated as "this text cannot be summarized" (presumably
        # too little content -- confirm with the API docs); use it as-is.
        return input_text

    print("Error Code: " + str(response.status_code))
    print("Error Message: " + str(detail))
    # Fall back to the input rather than returning None, so callers that do
    # len(summary) or translate the result keep working.
    return input_text

def _wait_for_mubert_track(poll_interval, max_polls):
    """Poll Mubert TrackStatus until it reports 'Done'; return True if it did within max_polls."""
    for _ in range(max_polls):
        r = requests.post('https://api-b2b.mubert.com/v2/TrackStatus',
            json={
                "method": "TrackStatus",
                "params": {
                    "pat": mubert_pat
                }
            })
        # NOTE(review): only tasks[0] is inspected, which assumes the track
        # just requested is listed first for this PAT -- confirm with Mubert.
        done = r.json()['data']['tasks'][0]['task_status_text'] == 'Done'
        # Wait after every poll: while pending this avoids hammering the API,
        # and after 'Done' it gives the file a moment before downloading.
        time.sleep(poll_interval)
        if done:
            return True
    return False


def _download_mubert_track(url, local_filename, retry_interval, max_attempts):
    """Stream url into local_filename, retrying non-200 answers; return the filename or None."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }
    status_code = None
    for _ in range(max_attempts):
        # ``with`` releases the streamed connection even when we retry.
        with requests.get(url, stream=True, headers=headers) as response:
            status_code = response.status_code
            if status_code == 200:
                with open(local_filename, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                print(f"{local_filename} 파일이 저장되었습니다.")
                return local_filename
        # Presumably the link is not served yet right after the track is
        # done; wait and retry.
        time.sleep(retry_interval)

    if status_code == 404:
        print("파일이 존재하지 않습니다.")
    else:
        print(f"파일 다운로드에 실패하였습니다. 에러 코드: {status_code}")
    return None


def get_mubert_music(text, duration=300, *, poll_interval=2.0, max_polls=300):
    """Generate a music track for ``text`` with the Mubert text-to-music API.

    The Korean ``text`` is summarized (get_summary), translated to English
    (translate_text) and cut to 200 characters to form the prompt. A track is
    requested, its status is polled until it reports 'Done', and it is then
    downloaded to ``mubert_music.mp3``.

    Args:
        text: Source text (Korean).
        duration: Value sent as Mubert's ``duration`` parameter.
        poll_interval: Seconds to wait between status polls and between
            download attempts.
        max_polls: Maximum number of status polls, and separately of download
            attempts, before giving up.

    Returns:
        ``"mubert_music.mp3"`` on success, otherwise ``None`` (the reason is
        printed).
    """
    print('original text length: ', len(text))
    # get_summary may return None when the summary API fails; use the raw text.
    summary = get_summary(text, 3) or text
    print('summary text length: ', len(summary))
    translated_text = translate_text(summary)
    print('translated_text length: ', len(translated_text))
    # Keep the prompt to at most 200 characters.
    translated_text = translated_text[:200]

    r = requests.post('https://api-b2b.mubert.com/v2/TTMRecordTrack',
        json={
            "method": "TTMRecordTrack",
            "params": {
                "text": translated_text,
                "pat": mubert_pat,
                "mode": "track",
                "duration": duration,
                "bitrate": 128
            }
        })
    rdata = r.json()
    if rdata['status'] != 1:
        # Report the rejected request instead of returning silently.
        print("Error Message: " + str(rdata))
        return None
    url = rdata['data']['tasks'][0]['download_link']

    if not _wait_for_mubert_track(poll_interval, max_polls):
        print(f"Mubert track not ready after {max_polls} status checks.")
        return None
    return _download_mubert_track(url, "mubert_music.mp3", poll_interval, max_polls)

def get_musicgen_music(text, duration=300):
    """Generate ``duration`` seconds of music for ``text`` with the shared MusicGen model.

    The Korean ``text`` is summarized (get_summary), translated to English
    (translate_text) and cut to 200 characters to form the text prompt. One
    segment is generated, then the audio is repeatedly extended with
    ``model.generate_continuation``, using the last ``overlap`` seconds as the
    audio prompt, until it is at least ``duration`` seconds long. It is then
    trimmed to exactly ``duration`` seconds and written to
    ``musicgen_output.wav`` with loudness normalization.

    Returns:
        The output file name, ``"musicgen_output.wav"``.
    """
    file_name = 'musicgen_output.wav'
    print('original text length: ', len(text))
    # get_summary may return None when the summary API fails; use the raw text.
    summary = get_summary(text, 3) or text
    print('summary text length: ', len(summary))
    translated_text = translate_text(summary)
    print('translated_text length: ', len(translated_text))
    translated_text = translated_text[:200]
    print(translated_text)
    start = time.time()
    overlap = 5  # seconds of existing audio fed back as the continuation prompt
    sr = model.sample_rate
    overlap_samples = int(overlap * sr)
    target_length = duration
    desc = [translated_text]
    print(sr)
    output = model.generate(descriptions=desc, progress=True)
    # Measure the first segment instead of assuming 30 s, so this stays
    # correct if the model's generation duration is changed.
    music_length = output.shape[2] / sr
    while music_length < target_length:
        last_sec = output[:, :, -overlap_samples:]
        cont = model.generate_continuation(last_sec, sr, descriptions=desc, progress=True)
        # Drop the overlap from the existing audio; ``cont`` starts with it.
        output = torch.cat([output[:, :, :-overlap_samples], cont], 2)
        new_length = output.shape[2] / sr
        if new_length <= music_length:
            # The continuation added no new audio; stop instead of looping forever.
            break
        music_length = new_length
    if music_length > target_length:
        output = output[:, :, :int(target_length * sr)]

    # First batch item, moved to CPU as float for writing.
    output = output.detach().cpu().float()[0]
    audio_write(
        file_name, output, sr, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)

    print(f'Elapsed time: {time.time() - start}')
    return file_name

# def get_story(first_sentence:str, history, num_sentences:int):
#     response = requests.post("https://api.openai.com/v1/chat/completions", 
#                             headers={"Content-Type": "application/json", "Authorization": f"Bearer {OPENAPI_KEY}"},
#                             data=json.dumps({
#                                 "model": "gpt-3.5-turbo",
#                                 "messages": [{"role": "system", "content": "You are a helpful assistant."}, 
#                                             {"role": "user", "content": f"""I will provide the first sentence of the novel, and please write {num_sentences} sentences continuing the story in a first-person protagonist's perspective in Korean. Don't number the sentences.
#                                             \n\nStory: {first_sentence}"""}]
#                             }))
#     print(response.json())
#     return response.json()['choices'][0]['message']['content']

def get_story(first_sentence:str, num_sentences:int, chatbot=None, history=None):
    """Stream a story continuation from the OpenAI chat-completions API.

    Asks ``gpt-3.5-turbo`` to continue ``first_sentence`` with
    ``num_sentences`` sentences in Korean, from a first-person perspective,
    and yields progressively longer partial results as tokens arrive.

    Args:
        first_sentence: Opening sentence of the story; appended to ``history``.
        num_sentences: Number of sentences to request.
        chatbot: Unused; kept for interface compatibility.
        history: List of alternating user/assistant messages. It is mutated in
            place: ``first_sentence`` is appended, then a reply entry is
            appended and updated as text streams in. A fresh list is used
            when omitted.

    Yields:
        ``(chat, history, response)``: ``chat`` is ``history`` paired into
        ``(user, assistant)`` tuples, and ``response`` is the streaming
        ``requests`` response.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status.
    """
    # A mutable default ([]) would be shared between calls, so every call
    # without an explicit history would keep appending to the same list.
    if history is None:
        history = []
    history.append(first_sentence)
    response = requests.post("https://api.openai.com/v1/chat/completions", 
                            headers={"Content-Type": "application/json", "Authorization": f"Bearer {OPENAPI_KEY}"},
                             stream=True,
                            data=json.dumps({
                                "stream": True,
                                "model": "gpt-3.5-turbo",
                                "messages": [{"role": "system", "content": "You are a helpful assistant."}, 
                                            {"role": "user", "content": f"""I will provide the first sentence of the novel, and please write {num_sentences} sentences continuing the story in a first-person protagonist's perspective in Korean. Don't number the sentences.
                                            \n\nFirst sentence: {first_sentence}"""}]
                            }))
    # An error body is plain JSON, not an event stream; fail loudly instead
    # of silently yielding nothing.
    response.raise_for_status()

    partial_words = ""
    first_token = True
    for raw_line in response.iter_lines():
        line = raw_line.decode()
        # Server-sent events: payload lines are "data: {...}", blank lines
        # separate events, and the stream ends with "data: [DONE]".
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        choices = json.loads(payload).get('choices') or []
        if not choices:
            continue
        content = choices[0].get('delta', {}).get("content")
        if not content:
            # e.g. the opening role-only delta or the closing empty delta
            continue
        partial_words += content
        if first_token:
            history.append(" " + partial_words)
            first_token = False
        else:
            history[-1] = partial_words
        # Pair up (user, assistant) messages; an unpaired trailing entry is dropped.
        chat = list(zip(history[::2], history[1::2]))
        yield chat, history, response


def get_voice_filename(text, gender, age, speed, pitch, alpha):
    """Translate Korean UI labels into get_voice() arguments and synthesize ``text``.

    ``gender`` equal to "남성" (male) selects a male voice; any other value
    selects a female voice. ``age`` must be one of "어린이" (child),
    "청소년" (teenager), "청년" (youth) or "중년" (middle-aged); for any other
    value nothing is synthesized.

    Returns:
        Whatever get_voice returns for ``filename="voice.mp3"``, or ``None``
        when ``age`` is not a recognized label.
    """
    age_labels = (
        ("어린이", "child"),
        ("청소년", "teenager"),
        ("청년", "youth"),
        ("중년", "middle_aged"),
    )
    # Linear scan with == (not a dict lookup) so any comparable value works.
    age_group = next((group for label, group in age_labels if age == label), None)
    if age_group is None:
        return None

    voice_gender = "male" if gender == '남성' else "female"
    return get_voice(
        text,
        gender=voice_gender,
        age_group=age_group,
        speed=speed,
        pitch=pitch,
        alpha=alpha,
        filename="voice.mp3",
    )