|
6 | 6 | #include <memory> |
7 | 7 | #include <string> |
8 | 8 | #include <iostream> |
9 | | -#include <yaml-cpp/yaml.h> |
| 9 | +#include <sstream> |
| 10 | + |
10 | 11 | #include "deepx/tensor.hpp" |
11 | 12 | #include "deepx/mem/mem.hpp" |
12 | 13 | #include "deepx/dtype.hpp" |
@@ -50,26 +51,62 @@ namespace deepx::op |
50 | 51 | throw std::runtime_error("backward not implemented"); |
51 | 52 | } |
52 | 53 |
|
53 | | - void load(const char *yml) |
54 | | - { |
55 | | - YAML::Node config = YAML::Load(yml); |
56 | | - name = config["name"].as<std::string>(); |
57 | | - dtype = config["dtype"].as<std::string>(); |
58 | | - if (config["args"]) |
59 | | - { |
60 | | - args = config["args"].as<std::vector<std::string>>(); |
61 | | - } |
62 | | - if (config["returns"]) |
63 | | - { |
64 | | - returns = config["returns"].as<std::vector<std::string>>(); |
| 54 | + void load(const char* str) { |
| 55 | + // 格式: opname dtype args returns require_grad args_grad returns_grad |
| 56 | + // 例子: "add float32 a,b c 1 a.grad,b.grad c.grad" |
| 57 | + // 或者: "add float32 a,b c 0" |
| 58 | + // 或者: "print a" |
| 59 | + |
| 60 | + stringstream ss(str); |
| 61 | + string token; |
| 62 | + |
| 63 | + // 读取操作名 |
| 64 | + ss >> name; |
| 65 | + |
| 66 | + // 读取数据类型 |
| 67 | + ss >> dtype; |
| 68 | + |
| 69 | + // 读取参数列表 (逗号分隔) |
| 70 | + ss >> token; |
| 71 | + args.clear(); |
| 72 | + stringstream args_ss(token); |
| 73 | + string arg; |
| 74 | + while (getline(args_ss, arg, ',')) { |
| 75 | + args.push_back(arg); |
65 | 76 | } |
66 | | - if (config["args_grad"]) |
67 | | - { |
68 | | - args_grad = config["args_grad"].as<std::vector<std::string>>(); |
| 77 | + |
| 78 | + // 读取返回值列表 |
| 79 | + ss >> token; |
| 80 | + returns.clear(); |
| 81 | + stringstream returns_ss(token); |
| 82 | + string ret; |
| 83 | + while (getline(returns_ss, ret, ',')) { |
| 84 | + returns.push_back(ret); |
69 | 85 | } |
70 | | - if (config["returns_grad"]) |
71 | | - { |
72 | | - returns_grad = config["returns_grad"].as<std::vector<std::string>>(); |
| 86 | + |
| 87 | + // 读取是否需要梯度 |
| 88 | + ss >> token; |
| 89 | + require_grad = (token == "1"); |
| 90 | + |
| 91 | + // 如果需要梯度,继续读取梯度变量名 |
| 92 | + if (require_grad && ss >> token) { |
| 93 | + // 读取参数梯度列表 |
| 94 | + args_grad.clear(); |
| 95 | + stringstream args_grad_ss(token); |
| 96 | + string arg_grad; |
| 97 | + while (getline(args_grad_ss, arg_grad, ',')) { |
| 98 | + args_grad.push_back(arg_grad); |
| 99 | + } |
| 100 | + |
| 101 | + // 读取返回值梯度列表 |
| 102 | + if (ss >> token) { |
| 103 | + returns_grad.clear(); |
| 104 | + stringstream returns_grad_ss(token); |
| 105 | + string ret_grad; |
| 106 | + while (getline(returns_grad_ss, ret_grad, ',')) { |
| 107 | + returns_grad.push_back(ret_grad); |
| 108 | + } |
| 109 | + } |
73 | 110 | } |
74 | 111 | } |
75 | 112 | void init(const string &opname, |
|
0 commit comments