IlyasMoutawwakil HF staff commited on
Commit
efb091c
·
1 Parent(s): b430e89

add diffusion

Browse files
Files changed (2) hide show
  1. app.py +4 -1
  2. config_store.py +6 -1
app.py CHANGED
@@ -22,6 +22,9 @@ from optimum_benchmark.backends.openvino.utils import (
22
  from optimum_benchmark.backends.transformers_utils import (
23
  TASKS_TO_AUTO_MODEL_CLASS_NAMES,
24
  )
 
 
 
25
  from optimum_benchmark import (
26
  Benchmark,
27
  BenchmarkConfig,
@@ -41,7 +44,7 @@ BACKENDS = ["pytorch", "openvino"]
41
  BENCHMARKS_HF_TOKEN = os.getenv("BENCHMARKS_HF_TOKEN")
42
  BENCHMARKS_REPO_ID = "optimum-benchmark/OpenVINO-Benchmarks"
43
  TASKS = set(TASKS_TO_OVMODELS.keys() | TASKS_TO_OVPIPELINES) & set(
44
- TASKS_TO_AUTO_MODEL_CLASS_NAMES.keys()
45
  )
46
 
47
 
 
22
  from optimum_benchmark.backends.transformers_utils import (
23
  TASKS_TO_AUTO_MODEL_CLASS_NAMES,
24
  )
25
+ from optimum_benchmark.backends.diffusers_utils import (
26
+ TASKS_TO_AUTO_PIPELINE_CLASS_NAMES,
27
+ )
28
  from optimum_benchmark import (
29
  Benchmark,
30
  BenchmarkConfig,
 
44
  BENCHMARKS_HF_TOKEN = os.getenv("BENCHMARKS_HF_TOKEN")
45
  BENCHMARKS_REPO_ID = "optimum-benchmark/OpenVINO-Benchmarks"
46
  TASKS = set(TASKS_TO_OVMODELS.keys() | TASKS_TO_OVPIPELINES) & set(
47
+ TASKS_TO_AUTO_MODEL_CLASS_NAMES.keys() | TASKS_TO_AUTO_PIPELINE_CLASS_NAMES.keys()
48
  )
49
 
50
 
config_store.py CHANGED
@@ -60,7 +60,12 @@ def get_inference_config():
60
  "inference.generate_kwargs": gr.Textbox(
61
  label="inference.generate_kwargs",
62
  value="{'max_new_tokens': 32, 'min_new_tokens': 32}",
63
- info="Additional python dict of kwargs to pass to the generate function",
 
 
 
 
 
64
  ),
65
  }
66
 
 
60
  "inference.generate_kwargs": gr.Textbox(
61
  label="inference.generate_kwargs",
62
  value="{'max_new_tokens': 32, 'min_new_tokens': 32}",
63
+ info="Additional python dict of kwargs to pass to the generate method",
64
+ ),
65
+ "inference.call_kwargs": gr.Textbox(
66
+ label="inference.call_kwargs",
67
+ value="{'num_inference_steps': 4}",
68
+ info="Additional python dict of kwargs to pass to the __call__ method",
69
  ),
70
  }
71