import requests
import zipfile 
import tempfile

def group_tests_by_duration(file_path: str, phase: "str | None" = "call") -> dict:
    """Bucket the entries of a pytest ``--durations`` report by runtime.

    Each relevant line is expected to look like ``"<seconds>s <phase> <test id>"``
    (e.g. ``"3.50s call tests/test_x.py::test_y"``). Lines that do not match
    that shape, or whose duration cannot be parsed, are skipped.

    Args:
        file_path: Path to the durations report.
        phase: Only lines whose second column equals this value are bucketed.
            Defaults to ``"call"`` so each test is counted once, the same
            filter ``extract_top_n_tests`` applies. Pass ``None`` to bucket
            every phase (setup/call/teardown) as separate entries.

    Returns:
        Dict mapping each bucket label ("0-5s", "5-10s", "10-15s", "15-20s",
        ">20s") to a list of ``(duration, test_name)`` tuples in file order.
        Every label is present, possibly with an empty list.
    """
    # Half-open [start, end) intervals paired index-wise with their labels.
    buckets = [(0, 5), (5, 10), (10, 15), (15, 20), (20, float('inf'))]
    bucket_names = ["0-5s", "5-10s", "10-15s", "15-20s", ">20s"]
    test_groups = {name: [] for name in bucket_names}

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split()
            # Skip headers, summaries and anything not shaped "<dur>s <phase> <name>".
            if len(parts) < 3 or not parts[0].endswith('s'):
                continue
            # Without this filter, setup/teardown lines inflate the per-bucket
            # test counts (one test can appear up to three times).
            if phase is not None and parts[1] != phase:
                continue
            try:
                duration = float(parts[0].rstrip('s'))
            except ValueError:
                continue
            test_name = ' '.join(parts[2:])  # test ids may contain spaces

            for (start, end), bucket_name in zip(buckets, bucket_names):
                if start <= duration < end:
                    test_groups[bucket_name].append((duration, test_name))
                    break

    return test_groups


def extract_top_n_tests(file_path, n=10):
    """Return the ``n`` slowest test calls from a pytest ``--durations`` report.

    Only lines whose second whitespace-separated column is ``call`` are
    considered, so setup/teardown phases are ignored. Lines whose first
    column cannot be parsed as a float after stripping trailing "s"
    characters are skipped.

    Args:
        file_path: Path to the durations report.
        n: Maximum number of tests to return.

    Returns:
        Dict mapping test name to its duration formatted as ``f"{float}s"``
        (e.g. ``"3.5s"``), inserted slowest first. Equal durations keep
        their file order.
    """
    test_durations = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split()
            if len(parts) < 3 or parts[1] != 'call':
                continue
            try:
                duration = float(parts[0].rstrip('s'))
            except ValueError:
                continue
            test_durations.append((duration, ' '.join(parts[2:])))

    # list.sort is stable even with reverse=True, so ties keep file order.
    test_durations.sort(key=lambda entry: entry[0], reverse=True)

    return {name: f"{duration}s" for duration, name in test_durations[:n]}


def fetch_test_duration_artifact(repo_id, token, run_id, artifact_name):
    """Download a workflow-run artifact from GitHub and extract its durations file.

    Lists the artifacts of workflow run ``run_id`` in ``repo_id`` through the
    GitHub REST API and downloads the zip archive of the one named
    ``artifact_name``. It then extracts the first archive member whose name
    contains "duration" into the current working directory. The archive is
    buffered in an anonymous temporary file that is discarded afterwards.

    Args:
        repo_id: Repository in ``"owner/repo"`` form.
        token: GitHub token sent in the ``Authorization`` header.
        run_id: ID of the workflow run whose artifacts are searched.
        artifact_name: Exact name of the artifact to download.

    Returns:
        Path of the extracted file, relative to the current working directory.

    Raises:
        requests.HTTPError: If listing or downloading the artifact fails.
        ValueError: If the run has no artifact named ``artifact_name``, or the
            archive contains no file whose name includes "duration".
    """
    timeout = 60  # seconds; without a timeout requests may block indefinitely

    owner_repo = repo_id.split("/")
    artifacts_url = (
        f'https://api.github.com/repos/{owner_repo[0]}/{owner_repo[1]}'
        f'/actions/runs/{run_id}/artifacts'
    )
    headers = {'Authorization': f'token {token}'}

    # The listing is paginated (30 items per page by default). Filter by name
    # and request the maximum page size so the artifact is not cut off.
    response = requests.get(
        artifacts_url,
        headers=headers,
        params={'name': artifact_name, 'per_page': 100},
        timeout=timeout,
    )
    response.raise_for_status()

    download_url = next(
        (
            artifact['archive_download_url']
            for artifact in response.json().get('artifacts', [])
            if artifact['name'] == artifact_name
        ),
        None,
    )
    if download_url is None:
        raise ValueError(
            f"Error 🥲: no artifact named {artifact_name!r} in run {run_id} of {repo_id}"
        )

    with requests.get(
        download_url, headers=headers, stream=True, timeout=timeout
    ) as download_response:
        download_response.raise_for_status()

        # Buffer the archive in a temp file instead of writing
        # '<artifact_name>.zip' into the working directory.
        with tempfile.TemporaryFile() as zip_buffer:
            for chunk in download_response.iter_content(chunk_size=64 * 1024):
                zip_buffer.write(chunk)

            with zipfile.ZipFile(zip_buffer) as zip_ref:
                duration_member = next(
                    (
                        name
                        for name in zip_ref.namelist()
                        if "duration" in name and not name.endswith("/")
                    ),
                    None,
                )
                if duration_member is None:
                    raise ValueError(
                        f"Error 🥲: artifact {artifact_name!r} contains no duration file"
                    )
                # extract() returns the path it actually wrote. Absolute or
                # '..' member names are sanitized by zipfile.
                return zip_ref.extract(duration_member, ".")
    
def format_to_markdown_str(test_bucket_map, top_n_slow_tests, repo_id, run_id, artifact_name):
    """Render the slow-test summary for one artifact as a Markdown string.

    Args:
        test_bucket_map: Mapping of bucket label to number of tests in it.
        top_n_slow_tests: Mapping of test id to formatted duration, in the
            order the entries should be listed. Only the part of each test id
            after its last "/" is shown.
        repo_id: Repository in ``"owner/repo"`` form, used for the run link.
        run_id: Workflow run ID, used for the run link.
        artifact_name: Artifact name shown in the heading.

    Returns:
        The Markdown text: a slow-test list, a bucket list, and a run URL.
    """
    run_url = f"https://github.com/{repo_id}/actions/runs/{run_id}/"

    sections = [f"\n## Top {len(top_n_slow_tests)} slow test for {artifact_name}\n\n"]
    for test, duration in top_n_slow_tests.items():
        sections.append(f"* {test.split('/')[-1]}: {duration}\n")

    sections.append("\n## Bucketed durations of the tests\n\n")
    for bucket, num_tests in test_bucket_map.items():
        # Prefix '>' labels with a backslash, presumably so Markdown does not
        # read "* >20s" as a blockquote. The old f"\{bucket}" relied on an
        # invalid escape sequence (a warning on modern Python).
        label = f"\\{bucket}" if ">" in bucket else bucket
        sections.append(f"* {label}: {num_tests} tests\n")

    sections.append(f"\nRun URL: [{run_url}]({run_url}).")

    return "".join(sections)
    

def analyze_tests(repo_id, token, run_id, artifact_name, top_n):
    """Fetch a run's durations artifact and summarize it as Markdown.

    Downloads and extracts the durations file with
    ``fetch_test_duration_artifact``. It then reduces the
    ``group_tests_by_duration`` result to a per-bucket count and takes the
    ``top_n`` slowest entries via ``extract_top_n_tests``. Both intermediate
    results are printed before the Markdown from ``format_to_markdown_str``
    is returned.
    """
    durations_path = fetch_test_duration_artifact(
        repo_id=repo_id,
        token=token,
        run_id=run_id,
        artifact_name=artifact_name,
    )

    bucket_counts = {}
    for label, entries in group_tests_by_duration(durations_path).items():
        bucket_counts[label] = len(entries)
    print(bucket_counts)

    slowest = extract_top_n_tests(durations_path, n=top_n)
    print(slowest)

    return format_to_markdown_str(bucket_counts, slowest, repo_id, run_id, artifact_name)