@@ -18,7 +18,7 @@ namespace deepx::mem
1818 unordered_map<string, std::any> args;
1919
2020 std::unordered_map<std::string, std::shared_ptr<TensorBase>> mem;
21- int tempidx= 0 ;
21+ int tempidx = 0 ;
2222
2323 // template <typename T>
2424 // static std::shared_ptr<void> type_erase(const std::shared_ptr<Tensor<T>> &ptr)
@@ -60,37 +60,56 @@ namespace deepx::mem
6060 template <typename T>
6161 void addarg (const string &name, const T value)
6262 {
63+ if (args.find (name) != args.end ())
64+ {
65+ cerr << " Argument already exists: " << name << endl;
66+ }
6367 args[name] = value;
6468 }
6569
6670 template <typename T>
6771 T getarg (const string &name) const
6872 {
73+ if (args.find (name) == args.end ())
74+ {
75+ cerr << " Argument not found: " << name << endl;
76+ return T ();
77+ }
6978 return any_cast<T>(args.at (name));
7079 }
7180
7281 template <typename T>
7382 void addvector (const string &name, const vector<T> &value)
7483 {
84+ if (args.find (name) != args.end ())
85+ {
86+ cerr << " Vector already exists: " << name << endl;
87+ return ;
88+ }
7589 args[name] = value;
7690 }
7791
7892 template <typename T>
7993 vector<T> getvector (const string &name) const
8094 {
95+ if (args.find (name) == args.end ())
96+ {
97+ cerr << " Vector not found: " << name << endl;
98+ return vector<T>();
99+ }
81100 auto v = any_cast<vector<T>>(args.at (name));
82101 return v;
83102 }
84103
104+ // tensor
85105
86- // tensor
87-
88106 template <typename T>
89107 void addtensor (const string &name, Tensor<T> &&tensor)
90108 {
91109 if (mem.find (name) != mem.end ())
92110 {
93- throw std::runtime_error (" Tensor already exists: " + name);
111+ cerr << " Tensor already exists: " << name << endl;
112+ return ;
94113 }
95114 auto ptr = std::make_shared<Tensor<T>>(std::move (tensor));
96115 mem[name] = ptr;
@@ -101,7 +120,8 @@ namespace deepx::mem
101120 {
102121 if (mem.find (name) != mem.end ())
103122 {
104- throw std::runtime_error (" Tensor already exists: " + name);
123+ cerr << " Tensor already exists: " << name << endl;
124+ return ;
105125 }
106126 auto ptr = std::make_shared<Tensor<T>>(tensor);
107127 mem[name] = ptr;
@@ -111,43 +131,42 @@ namespace deepx::mem
111131 shared_ptr<Tensor<T>> temptensor (vector<int > shape)
112132 {
113133 // 直接构造到shared_ptr避免移动
114- auto temp = tensorfunc::New<T>(shape); // 临时对象
115- auto cloned = make_shared<Tensor<T>>(std::move (temp));
116- mem[" temp" + to_string (tempidx)]= cloned;
134+ auto temp = tensorfunc::New<T>(shape); // 临时对象
135+ auto cloned = make_shared<Tensor<T>>(std::move (temp));
136+ mem[" temp" + to_string (tempidx)] = cloned;
117137 tempidx++;
118138 return cloned;
119139 }
120-
140+
121141 bool existstensor (const string &name) const
122142 {
123143 return mem.find (name) != mem.end ();
124- }
144+ }
125145
126- template <typename T>
146+ template <typename T>
127147 shared_ptr<Tensor<T>> gettensor (const string &name) const
128148 {
129- auto ptr = mem.at (name);
149+ auto ptr = mem.at (name);
130150 return std::static_pointer_cast<Tensor<T>>(ptr);
131151 }
132152
133-
134153 // 获取多个张量
135154 template <typename T>
136155 vector<Tensor<T> *> gettensors (const vector<string> &names) const
137156 {
138157 std::vector<Tensor<T> *> tensors;
139- try
158+
159+ for (const auto &name : names)
140160 {
141- for ( const auto & name : names )
161+ if (mem. find ( name) == mem. end () )
142162 {
143- auto ptr = mem. at ( name) ;
144- tensors. push_back (std::static_pointer_cast<Tensor<T>>(ptr). get ()) ;
163+ cerr << " Tensor not found: " << name << endl ;
164+ continue ;
145165 }
166+ auto ptr = mem.at (name);
167+ tensors.push_back (std::static_pointer_cast<Tensor<T>>(ptr).get ());
146168 }
147- catch (const std::out_of_range &)
148- {
149- throw std::runtime_error (" Type mismatch or tensor not found" );
150- }
169+
151170 return tensors;
152171 }
153172
0 commit comments