TheoLvs commited on
Commit
e961708
·
1 Parent(s): 103fbf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -26,12 +26,12 @@ DEFAULT_PARAMS = {
26
  "test_seed": 42, # must be non-negative
27
  },
28
  "image":{
29
- "dataset_name": "QuotaClimat/frugalaichallenge-image-train",
30
  "test_size": 0.2, # must be between 0 and 1
31
  "test_seed": 42, # must be non-negative
32
  },
33
  "audio":{
34
- "dataset_name": "QuotaClimat/frugalaichallenge-audio-train",
35
  "test_size": 0.2, # must be between 0 and 1
36
  "test_seed": 42, # must be non-negative
37
  }
@@ -61,19 +61,33 @@ def evaluate_model(task: str, space_url: str):
61
 
62
  results = response.json()
63
 
64
- # Check for required keys
65
- required_keys = {
66
- "username", "space_url", "submission_timestamp", "model_description",
67
- "accuracy", "energy_consumed_wh", "emissions_gco2eq", "emissions_data",
68
  "api_route", "dataset_config"
69
  }
 
 
 
 
 
 
 
 
70
 
71
  missing_keys = required_keys - set(results.keys())
72
  if missing_keys:
73
  return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
74
 
 
 
 
 
 
 
75
  return (
76
- results["accuracy"],
77
  results["emissions_gco2eq"],
78
  results["energy_consumed_wh"],
79
  results
 
26
  "test_seed": 42, # must be non-negative
27
  },
28
  "image":{
29
+ "dataset_name": "pyronear/pyro-sdis",
30
  "test_size": 0.2, # must be between 0 and 1
31
  "test_seed": 42, # must be non-negative
32
  },
33
  "audio":{
34
+ "dataset_name": "rfcx/frugalai",
35
  "test_size": 0.2, # must be between 0 and 1
36
  "test_seed": 42, # must be non-negative
37
  }
 
61
 
62
  results = response.json()
63
 
64
+ # Check for required keys based on task
65
+ base_required_keys = {
66
+ "username", "space_url", "submission_timestamp", "model_description",
67
+ "energy_consumed_wh", "emissions_gco2eq", "emissions_data",
68
  "api_route", "dataset_config"
69
  }
70
+
71
+ # Add task-specific accuracy keys
72
+ if task == "image":
73
+ accuracy_keys = {"classification_accuracy", "mean_iou"}
74
+ else: # text and audio
75
+ accuracy_keys = {"accuracy"}
76
+
77
+ required_keys = base_required_keys | accuracy_keys
78
 
79
  missing_keys = required_keys - set(results.keys())
80
  if missing_keys:
81
  return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
82
 
83
+ # Return appropriate accuracy metric based on task
84
+ if task == "image":
85
+ accuracy = results["classification_accuracy"] # For display in UI
86
+ else:
87
+ accuracy = results["accuracy"]
88
+
89
  return (
90
+ accuracy,
91
  results["emissions_gco2eq"],
92
  results["energy_consumed_wh"],
93
  results