Skip to content

Commit 4426b67

Browse files
committed
client:指令格式调试 (client: debugging of the instruction/command format)
1 parent ee74444 commit 4426b67

File tree

9 files changed

+92
-51
lines changed

9 files changed

+92
-51
lines changed

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,20 @@ int main()
3030
deepx::op::Op op;
3131
op.load(buffer);
3232

33-
shared_ptr<deepx::op::Op> opsrc = opfactory.get_op(op);
33+
34+
if (opfactory.ops.find(op.name)==opfactory.ops.end()){
35+
cout<<"<op> "<<op.name<<" not found"<<endl;
36+
return;
37+
}
38+
auto &type_map = opfactory.ops.find(op.name)->second;
39+
if (type_map.find(op.dtype)==type_map.end()){
40+
cout<<"<op>"<<op.name<<" "<<op.dtype<<" not found"<<endl;
41+
return;
42+
}
43+
auto src = type_map.find(op.dtype)->second;
3444

35-
(*opsrc).init(op.name, op.dtype, op.args, op.returns, op.require_grad, op.args_grad, op.returns_grad);
36-
(*opsrc).forward(mem);
45+
(*src).init(op.name, op.dtype, op.args, op.returns, op.require_grad, op.args_grad, op.returns_grad);
46+
(*src).forward(mem);
3747
};
3848
server.start();
3949
return 0;

excuter/op-mem-ompsimd/src/client/udpserver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace client
4444
0, (struct sockaddr *)&cliaddr,
4545
&len);
4646
buffer[n] = '\0';
47-
std::cout << "Recv "<<n<<"bytes:" << buffer << std::endl;
47+
std::cout << ">>"<< buffer << std::endl;
4848
func(buffer);
4949
}
5050

excuter/op-mem-ompsimd/src/deepx/mem/mem.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,20 @@ namespace deepx::mem
117117
tempidx++;
118118
return cloned;
119119
}
120-
template <typename T>
120+
121121
bool existstensor(const string &name) const
122122
{
123123
return mem.find(name) != mem.end();
124124
}
125125

126-
template <typename T>
126+
template <typename T>
127127
shared_ptr<Tensor<T>> gettensor(const string &name) const
128128
{
129129
auto ptr = mem.at(name);
130130
return std::static_pointer_cast<Tensor<T>>(ptr);
131131
}
132-
132+
133+
133134
// 获取多个张量
134135
template <typename T>
135136
vector<Tensor<T> *> gettensors(const vector<string> &names) const

excuter/op-mem-ompsimd/src/deepx/op/init.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace deepx::op{
1818
}
1919
void forward(mem::Mem &mem) override{
2020
auto output = mem.gettensor<T>(this->returns[0]).get();
21-
T low = mem.getarg<T>(this->args[1]);
22-
T high = mem.getarg<T>(this->args[2]);
21+
T low = mem.getarg<T>(this->args[0]);
22+
T high = mem.getarg<T>(this->args[1]);
2323
tensorfunc::uniform(*output,low,high);
2424
}
2525
void backward(mem::Mem &mem) override{

excuter/op-mem-ompsimd/src/deepx/op/new.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace deepx::op{
1818
this->init("newtensor",dtype<T>::name(), args, {}, false, {}, {});
1919
}
2020
void forward(mem::Mem &mem) override{
21-
string name= this->args[0];
22-
vector<int> shape=mem.getvector<int>(this->args[1]);
21+
string name= this->returns[0];
22+
vector<int> shape=mem.getvector<int32_t>(this->args[0]);
2323
Tensor<T> t=tensorfunc::New<T>(shape);
2424
mem.addtensor(name,t);
2525
}

excuter/op-mem-ompsimd/src/deepx/op/op.hpp

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#include <memory>
77
#include <string>
88
#include <iostream>
9-
#include <yaml-cpp/yaml.h>
9+
#include <sstream>
10+
1011
#include "deepx/tensor.hpp"
1112
#include "deepx/mem/mem.hpp"
1213
#include "deepx/dtype.hpp"
@@ -50,26 +51,62 @@ namespace deepx::op
5051
throw std::runtime_error("backward not implemented");
5152
}
5253

53-
void load(const char *yml)
54-
{
55-
YAML::Node config = YAML::Load(yml);
56-
name = config["name"].as<std::string>();
57-
dtype = config["dtype"].as<std::string>();
58-
if (config["args"])
59-
{
60-
args = config["args"].as<std::vector<std::string>>();
61-
}
62-
if (config["returns"])
63-
{
64-
returns = config["returns"].as<std::vector<std::string>>();
54+
void load(const char* str) {
55+
// 格式: opname dtype args returns require_grad args_grad returns_grad
56+
// 例子: "add float32 a,b c 1 a.grad,b.grad c.grad"
57+
// 或者: "add float32 a,b c 0"
58+
// 或者: "print a"
59+
60+
stringstream ss(str);
61+
string token;
62+
63+
// 读取操作名
64+
ss >> name;
65+
66+
// 读取数据类型
67+
ss >> dtype;
68+
69+
// 读取参数列表 (逗号分隔)
70+
ss >> token;
71+
args.clear();
72+
stringstream args_ss(token);
73+
string arg;
74+
while (getline(args_ss, arg, ',')) {
75+
args.push_back(arg);
6576
}
66-
if (config["args_grad"])
67-
{
68-
args_grad = config["args_grad"].as<std::vector<std::string>>();
77+
78+
// 读取返回值列表
79+
ss >> token;
80+
returns.clear();
81+
stringstream returns_ss(token);
82+
string ret;
83+
while (getline(returns_ss, ret, ',')) {
84+
returns.push_back(ret);
6985
}
70-
if (config["returns_grad"])
71-
{
72-
returns_grad = config["returns_grad"].as<std::vector<std::string>>();
86+
87+
// 读取是否需要梯度
88+
ss >> token;
89+
require_grad = (token == "1");
90+
91+
// 如果需要梯度,继续读取梯度变量名
92+
if (require_grad && ss >> token) {
93+
// 读取参数梯度列表
94+
args_grad.clear();
95+
stringstream args_grad_ss(token);
96+
string arg_grad;
97+
while (getline(args_grad_ss, arg_grad, ',')) {
98+
args_grad.push_back(arg_grad);
99+
}
100+
101+
// 读取返回值梯度列表
102+
if (ss >> token) {
103+
returns_grad.clear();
104+
stringstream returns_grad_ss(token);
105+
string ret_grad;
106+
while (getline(returns_grad_ss, ret_grad, ',')) {
107+
returns_grad.push_back(ret_grad);
108+
}
109+
}
73110
}
74111
}
75112
void init(const string &opname,

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "deepx/op/init.hpp"
66
#include "deepx/op/new.hpp"
77
#include "deepx/op/arg.hpp"
8+
#include "deepx/op/print.hpp"
89
namespace deepx::op
910
{
1011
//new
@@ -38,6 +39,10 @@ namespace deepx::op
3839
register_constant(opfactory);
3940
register_arange(opfactory);
4041
}
42+
//print
43+
void register_print(OpFactory &opfactory){
44+
opfactory.add_op(Print<float>());
45+
}
4146
//elementwise
4247
void register_add(OpFactory &opfactory){
4348
opfactory.add_op(Add<float>());
@@ -110,6 +115,7 @@ namespace deepx::op
110115
int register_all(OpFactory &opfactory){
111116
register_new(opfactory);
112117
register_init(opfactory);
118+
register_print(opfactory);
113119
register_elementwise_op(opfactory);
114120
register_concat(opfactory);
115121
register_matmul(opfactory);

excuter/op-mem-ompsimd/src/deepx/op/opfactory.hpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,14 @@ namespace deepx::op
1414

1515
class OpFactory
1616
{
17-
private:
18-
std::unordered_map<std::string, Op_dtype> ops;
19-
2017
public:
21-
18+
std::unordered_map<std::string, Op_dtype> ops;
2219
template <typename T>
2320
void add_op(const T &op)
2421
{
2522
ops[op.name][op.dtype] = std::make_shared<T>(op);
2623
}
27-
28-
std::shared_ptr<Op> get_op(const Op &op)
29-
{
30-
auto &type_map = ops[op.name];
31-
auto it = type_map.find(op.dtype);
32-
if (it != type_map.end())
33-
{
34-
auto src = it->second;
35-
return src;
36-
}
37-
return nullptr;
38-
}
24+
3925

4026
void print(){
4127
cout<<"support op:"<<endl;

excuter/op-mem-ompsimd/src/deepx/op/print.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22
#define DEEPX_OP_PRINT_HPP
33

44
#include "deepx/op/op.hpp"
5+
#include "deepx/tensorfunc/print.hpp"
56

67
namespace deepx::op{
7-
8+
template <typename T>
89
class Print : public Op{
910
public:
1011
Print(){
11-
this->init("print","", {}, {}, false, {}, {});
12+
this->init("print","any", {}, {}, false, {}, {});
1213
}
1314
void forward(mem::Mem &mem) override{
1415
string name=this->returns[0];
15-
if (mem.existtensor(name)){
16+
if (mem.existstensor(name)){
1617
auto t=mem.gettensor<T>(name);
17-
cout<<t<<endl;
18+
tensorfunc::print<T>(*t);
1819
}else{
19-
throw std::runtime_error("Print op does not support backward");
20+
cout<<"<print> "<<name<<" not found"<<endl;
2021
}
2122
}
2223
void backward(mem::Mem &mem) override{

0 commit comments

Comments
 (0)