diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index c2e515c6..ae423415 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -126,7 +126,6 @@ async def create_toml_file(request: Request): config: dict = json.loads(json_data.decode("utf-8")) train_utils.fix_config_types(config) - gpu_ids = config.pop("gpu_ids", None) suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2 @@ -163,7 +162,11 @@ async def create_toml_file(request: Request): with open(toml_file, "w", encoding="utf-8") as f: f.write(toml.dumps(config)) - + if not os.path.exists(config['output_dir']): + os.makedirs(config['output_dir'], exist_ok=True) + with open(os.path.join(config['output_dir'],"config.toml"),'w', encoding="utf-8") as f: + f.write(toml.dumps(config)) + result = process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads) return result diff --git a/mikazuki/process.py b/mikazuki/process.py index 7fd37ed8..35524f57 100644 --- a/mikazuki/process.py +++ b/mikazuki/process.py @@ -55,4 +55,4 @@ def _run(): coro = asyncio.to_thread(_run) asyncio.create_task(coro) - return APIResponse(status="success", message=f"Training started / 训练开始 ID: {task.task_id}") + return APIResponse(status="success", message=f"Training started / 训练开始 ID: {task.task_id}",data={"task_id": task.task_id})