@@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context {
207207 size_t max_size;
208208};
209209
210- struct graph_cache {
211-
212- bool is_cached (const ggml_cgraph * cgraph) {
213- if ((int )last_graph.size () != cgraph->n_nodes ) {
214- return false ;
215- }
216- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
217- if (memcmp (&last_graph[i], cgraph->nodes [i], sizeof (ggml_tensor)) != 0 ) {
218- return false ;
219- }
220- }
221- return true ;
222- }
223-
224- void add (const ggml_cgraph * cgraph) {
225- last_graph.resize (cgraph->n_nodes );
226- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
227- memcpy (&last_graph[i], cgraph->nodes [i], sizeof (ggml_tensor));
228- }
229- }
230-
231- std::vector<ggml_tensor> last_graph;
232- };
233-
234210struct ggml_backend_rpc_context {
235211 std::string endpoint;
236212 uint32_t device;
237213 std::string name;
238- graph_cache gc ;
214+ uint64_t last_graph_uid ;
239215};
240216
241217struct ggml_backend_rpc_buffer_context {
@@ -717,15 +693,15 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
717693 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
718694
719695 GGML_ASSERT (cgraph->n_nodes > 0 );
720- bool reuse = rpc_ctx->gc . is_cached ( cgraph) ;
696+ bool reuse = cgraph-> uid != 0 && rpc_ctx->last_graph_uid == cgraph-> uid ;
721697 if (reuse) {
722698 rpc_msg_graph_recompute_req request;
723699 request.device = rpc_ctx->device ;
724700 auto sock = get_socket (rpc_ctx->endpoint );
725701 bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_RECOMPUTE , &request, sizeof (request));
726702 RPC_STATUS_ASSERT (status);
727703 } else {
728- rpc_ctx->gc . add ( cgraph) ;
704+ rpc_ctx->last_graph_uid = cgraph-> uid ;
729705 std::vector<uint8_t > input;
730706 serialize_graph (rpc_ctx->device , cgraph, input);
731707 auto sock = get_socket (rpc_ctx->endpoint );
@@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
791767ggml_backend_t ggml_backend_rpc_init (const char * endpoint, uint32_t device) {
792768 std::string dev_name = " RPC" + std::to_string (device) + " [" + std::string (endpoint) + " ]" ;
793769 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
794- /* .endpoint = */ endpoint,
795- /* .device = */ device,
796- /* .name = */ dev_name,
797- /* .gc = */ {} ,
770+ /* .endpoint = */ endpoint,
771+ /* .device = */ device,
772+ /* .name = */ dev_name,
773+ /* .last_graph_uid = */ 0 ,
798774 };
799775 auto reg = ggml_backend_rpc_add_server (endpoint);
800776 ggml_backend_t backend = new ggml_backend {
0 commit comments