Skip to content

Commit ae3c80e

Browse files
committed
client:解决了并行问题
1 parent cb2e54e commit ae3c80e

3 files changed

Lines changed: 63 additions & 33 deletions

File tree

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <mutex>
2+
13
#include <deepx/tensorfunc/init.hpp>
24
#include <deepx/tensorfunc/new.hpp>
35
#include <deepx/tensorfunc/print.hpp>
@@ -12,7 +14,8 @@ using namespace deepx::mem;
1214

1315
int main()
1416
{
15-
Mem mem;
17+
Mem mem;
18+
std::mutex memmutex;
1619
deepx::Tensor<float> tensor = New<float>({1, 2, 3});
1720
uniform(tensor,-1.0f,1.0f);
1821
mem.addtensor("tensor", tensor);
@@ -25,12 +28,10 @@ int main()
2528
deepx::op::OpFactory opfactory;
2629
deepx::op::register_all(opfactory);
2730
opfactory.print();
28-
server.func = [&mem, &opfactory](const char *buffer)
31+
server.func = [&mem, &opfactory, &memmutex](const char *buffer)
2932
{
3033
deepx::op::Op op;
3134
op.load(buffer);
32-
33-
3435
if (opfactory.ops.find(op.name)==opfactory.ops.end()){
3536
cout<<"<op> "<<op.name<<" not found"<<endl;
3637
return;
@@ -43,7 +44,9 @@ int main()
4344
auto src = type_map.find(op.dtype)->second;
4445

4546
(*src).init(op.name, op.dtype, op.args, op.returns, op.require_grad, op.args_grad, op.returns_grad);
47+
memmutex.lock();
4648
(*src).forward(mem);
49+
memmutex.unlock();
4750
};
4851
server.start();
4952
return 0;

excuter/op-mem-ompsimd/src/client/udpserver.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
#include <sstream>
2+
13
#include "udpserver.hpp"
24

35
namespace client
46
{
7+
using namespace std;
58
udpserver::udpserver(int port)
69
{
710
this->port = port;
@@ -38,16 +41,21 @@ namespace client
3841
}
3942
while (true)
4043
{
41-
len = sizeof(cliaddr); // len is value/result
42-
// 接收消息
43-
n = recvfrom(sockfd, (char *)buffer, 1024,
44-
0, (struct sockaddr *)&cliaddr,
45-
&len);
44+
len = sizeof(cliaddr);
45+
n = recvfrom(sockfd, (char *)buffer, 1024, 0, (struct sockaddr *)&cliaddr, &len);
4646
buffer[n] = '\0';
47-
std::cout << "~"<< buffer;
48-
func(buffer);
47+
48+
// 新增换行拆分逻辑
49+
stringstream ss(buffer);
50+
string line;
51+
while (getline(ss, line)) {
52+
if (!line.empty()) {
53+
cout << "~" << line << endl;
54+
char *IR = const_cast<char *>(line.c_str());
55+
func(IR);
56+
}
57+
}
4958
}
50-
5159
close(sockfd);
5260
}
5361
}

excuter/op-mem-ompsimd/src/deepx/mem/mem.hpp

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace deepx::mem
1818
unordered_map<string, std::any> args;
1919

2020
std::unordered_map<std::string, std::shared_ptr<TensorBase>> mem;
21-
int tempidx=0;
21+
int tempidx = 0;
2222

2323
// template <typename T>
2424
// static std::shared_ptr<void> type_erase(const std::shared_ptr<Tensor<T>> &ptr)
@@ -60,37 +60,56 @@ namespace deepx::mem
6060
template <typename T>
6161
void addarg(const string &name, const T value)
6262
{
63+
if (args.find(name) != args.end())
64+
{
65+
cerr << "Argument already exists: " << name << endl;
66+
}
6367
args[name] = value;
6468
}
6569

6670
template <typename T>
6771
T getarg(const string &name) const
6872
{
73+
if (args.find(name) == args.end())
74+
{
75+
cerr << "Argument not found: " << name << endl;
76+
return T();
77+
}
6978
return any_cast<T>(args.at(name));
7079
}
7180

7281
template <typename T>
7382
void addvector(const string &name, const vector<T> &value)
7483
{
84+
if (args.find(name) != args.end())
85+
{
86+
cerr << "Vector already exists: " << name << endl;
87+
return;
88+
}
7589
args[name] = value;
7690
}
7791

7892
template <typename T>
7993
vector<T> getvector(const string &name) const
8094
{
95+
if (args.find(name) == args.end())
96+
{
97+
cerr << "Vector not found: " << name << endl;
98+
return vector<T>();
99+
}
81100
auto v = any_cast<vector<T>>(args.at(name));
82101
return v;
83102
}
84103

104+
// tensor
85105

86-
//tensor
87-
88106
template <typename T>
89107
void addtensor(const string &name, Tensor<T> &&tensor)
90108
{
91109
if (mem.find(name) != mem.end())
92110
{
93-
throw std::runtime_error("Tensor already exists: " + name);
111+
cerr << "Tensor already exists: " << name << endl;
112+
return;
94113
}
95114
auto ptr = std::make_shared<Tensor<T>>(std::move(tensor));
96115
mem[name] = ptr;
@@ -101,7 +120,8 @@ namespace deepx::mem
101120
{
102121
if (mem.find(name) != mem.end())
103122
{
104-
throw std::runtime_error("Tensor already exists: " + name);
123+
cerr << "Tensor already exists: " << name << endl;
124+
return;
105125
}
106126
auto ptr = std::make_shared<Tensor<T>>(tensor);
107127
mem[name] = ptr;
@@ -111,43 +131,42 @@ namespace deepx::mem
111131
shared_ptr<Tensor<T>> temptensor(vector<int> shape)
112132
{
113133
// 直接构造到shared_ptr避免移动
114-
auto temp = tensorfunc::New<T>(shape); // 临时对象
115-
auto cloned = make_shared<Tensor<T>>(std::move(temp));
116-
mem["temp"+to_string(tempidx)]=cloned;
134+
auto temp = tensorfunc::New<T>(shape); // 临时对象
135+
auto cloned = make_shared<Tensor<T>>(std::move(temp));
136+
mem["temp" + to_string(tempidx)] = cloned;
117137
tempidx++;
118138
return cloned;
119139
}
120-
140+
121141
bool existstensor(const string &name) const
122142
{
123143
return mem.find(name) != mem.end();
124-
}
144+
}
125145

126-
template <typename T>
146+
template <typename T>
127147
shared_ptr<Tensor<T>> gettensor(const string &name) const
128148
{
129-
auto ptr = mem.at(name);
149+
auto ptr = mem.at(name);
130150
return std::static_pointer_cast<Tensor<T>>(ptr);
131151
}
132152

133-
134153
// 获取多个张量
135154
template <typename T>
136155
vector<Tensor<T> *> gettensors(const vector<string> &names) const
137156
{
138157
std::vector<Tensor<T> *> tensors;
139-
try
158+
159+
for (const auto &name : names)
140160
{
141-
for (const auto &name : names)
161+
if (mem.find(name) == mem.end())
142162
{
143-
auto ptr = mem.at(name);
144-
tensors.push_back(std::static_pointer_cast<Tensor<T>>(ptr).get());
163+
cerr << "Tensor not found: " << name << endl;
164+
continue;
145165
}
166+
auto ptr = mem.at(name);
167+
tensors.push_back(std::static_pointer_cast<Tensor<T>>(ptr).get());
146168
}
147-
catch (const std::out_of_range &)
148-
{
149-
throw std::runtime_error("Type mismatch or tensor not found");
150-
}
169+
151170
return tensors;
152171
}
153172

0 commit comments

Comments
 (0)