# syngen / app.py
# Pavel_Henrykhsen
# Initial commit
# 42e611f
# streamlit_app.py
import streamlit as st
import os
import pandas as pd
from syngen.ml.worker import Worker
import queue
from loguru import logger
import threading
# Use streamlit to run the application: `streamlit run <this file>`.
# Streamlit re-runs this script from the top on every widget interaction,
# so anything with global side effects (e.g. logger sinks) must not be
# registered unconditionally at script level.

# Directory where uploaded CSV files are saved so syngen can read them.
UPLOAD_DIRECTORY = "uploaded_files"

# Training settings applied to every table; "source" is added per file.
TRAIN_SETTINGS = {
    "epochs": 2,
    "drop_null": False,
    "print_report": False,
    "row_limit": None,
    "batch_size": 32,
}


def safe_upload_path(directory, filename):
    """Return the path inside *directory* under which to store an upload.

    The filename is supplied by the browser, so only its final path
    component is kept: a name such as ``"../../x.csv"`` cannot escape
    *directory*. Both ``/`` and ``\\`` are treated as separators.

    Raises:
        ValueError: if no usable file name remains after stripping the path.
    """
    base = os.path.basename(filename.replace("\\", "/"))
    if base in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {filename!r}")
    return os.path.join(directory, base)


def stream_logs(log_queue, worker, on_update, poll_interval=0.5):
    """Forward queued log messages to *on_update* until *worker* finishes.

    Each time a message arrives, *on_update* is called with every message
    collected so far, joined by newlines. The display therefore shows the
    whole log, not just the latest line.

    Polling continues while *worker* is alive. After that the queue is
    drained, so messages logged just before the thread exits are not lost.

    Args:
        log_queue: ``queue.Queue`` that producers put log messages into.
        worker: thread (anything with ``is_alive()``) producing the messages.
        on_update: callback receiving the accumulated log text.
        poll_interval: seconds to block on the queue before re-checking
            whether *worker* is still alive.

    Returns:
        The list of messages received, in arrival order.
    """
    lines = []
    # Check is_alive() first: once it is False, every put() has already
    # happened, so an empty queue really means "done".
    while worker.is_alive() or not log_queue.empty():
        try:
            lines.append(log_queue.get(timeout=poll_interval))
        except queue.Empty:
            continue
        on_update("\n".join(lines))
    return lines


def train_model(tables):
    """Train one syngen model per uploaded table.

    Args:
        tables: iterable of ``(table_name, source_path)`` pairs, one per
            saved CSV file.
    """
    logger.info("Starting model training...")
    for table_name, source_path in tables:
        worker = Worker(
            table_name=table_name,
            settings={"source": source_path, **TRAIN_SETTINGS},
            # NOTE(review): the YAML the user enters is saved to config.yaml
            # but never passed here, so the relationships it defines are not
            # used — confirm whether metadata_path should point at it.
            metadata_path=None,
            log_level='DEBUG',
            type="train"
        )
        worker.launch_train()
    logger.info("Model training completed.")


def main():
    """Render the page: upload CSVs, edit the YAML config, launch training."""
    st.title("HuggingFace Streamlit App with Syngen")
    st.write("Upload CSV files, define relationships, and train your model.")

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

    uploaded_files = st.file_uploader(
        "Upload CSV files", type="csv", accept_multiple_files=True)

    # (table_name, saved_path) for every upload that parsed as CSV.
    saved_tables = []
    for uploaded_file in uploaded_files or []:
        try:
            file_path = safe_upload_path(UPLOAD_DIRECTORY, uploaded_file.name)
        except ValueError as err:
            st.error(str(err))
            continue
        # Persist to disk: syngen's Worker reads its training data by path.
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getvalue())
        try:
            df = pd.read_csv(file_path)
        except (pd.errors.ParserError, pd.errors.EmptyDataError,
                UnicodeDecodeError) as err:
            st.error(f"Could not read {uploaded_file.name} as CSV: {err}")
            continue
        saved_tables.append((os.path.basename(file_path), file_path))
        st.write(f"Preview of {uploaded_file.name}:", df.head())

    # YAML configuration editor.
    st.subheader('YAML Configuration')
    yaml_config = st.text_area(
        "Define the relationships between the CSV files:", "")
    if yaml_config:
        st.code(yaml_config, language="yaml")

    if not st.button('Start Model Training'):
        return
    if not (saved_tables and yaml_config.strip()):
        st.warning(
            "Please upload CSV files and provide a valid YAML configuration.")
        return

    with open("config.yaml", "w", encoding="utf-8") as f:
        f.write(yaml_config)

    log_queue = queue.Queue()
    # Attach the sink only for this training run. A script-level
    # logger.add() would add another sink on every Streamlit rerun.
    sink_id = logger.add(
        lambda message: log_queue.put(message.record["message"]))
    try:
        log_display = st.empty()  # placeholder that is rewritten with logs
        # logger.catch logs any training exception (so it reaches the sink
        # and the page) instead of letting it die silently in the thread.
        training_thread = threading.Thread(
            target=logger.catch(train_model), args=(saved_tables,))
        training_thread.start()
        st.info("Training started. Please wait...")
        stream_logs(log_queue, training_thread, log_display.text)
        training_thread.join()
    finally:
        try:
            logger.remove(sink_id)
        except ValueError:
            # The handler was already removed elsewhere (e.g. by a
            # logger.remove() call with no id); nothing left to detach.
            pass

    # NOTE(review): nothing in this file writes generated_file.csv or
    # model.bin, and relative links like these are presumably not served by
    # Streamlit. Consider st.download_button once syngen's output paths
    # are known.
    st.markdown(
        "Download [Generated File](./generated_file.csv) and [Model Binaries](./model.bin).")


if __name__ == "__main__":
    main()