Skip to content

Commit d5003b6

Browse files
authored
rpc : use graph uid instead of graph cache (ggml-org#22701)
Store the last graph uid and compare against it to determine if the same graph is being computed.
1 parent 2635ac7 commit d5003b6

1 file changed

Lines changed: 7 additions & 31 deletions

File tree

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context {
207207
size_t max_size;
208208
};
209209

210-
struct graph_cache {
211-
212-
bool is_cached(const ggml_cgraph * cgraph) {
213-
if ((int)last_graph.size() != cgraph->n_nodes) {
214-
return false;
215-
}
216-
for (int i = 0; i < cgraph->n_nodes; i++) {
217-
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
218-
return false;
219-
}
220-
}
221-
return true;
222-
}
223-
224-
void add(const ggml_cgraph * cgraph) {
225-
last_graph.resize(cgraph->n_nodes);
226-
for (int i = 0; i < cgraph->n_nodes; i++) {
227-
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
228-
}
229-
}
230-
231-
std::vector<ggml_tensor> last_graph;
232-
};
233-
234210
struct ggml_backend_rpc_context {
235211
std::string endpoint;
236212
uint32_t device;
237213
std::string name;
238-
graph_cache gc;
214+
uint64_t last_graph_uid;
239215
};
240216

241217
struct ggml_backend_rpc_buffer_context {
@@ -717,15 +693,15 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
717693
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
718694

719695
GGML_ASSERT(cgraph->n_nodes > 0);
720-
bool reuse = rpc_ctx->gc.is_cached(cgraph);
696+
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
721697
if (reuse) {
722698
rpc_msg_graph_recompute_req request;
723699
request.device = rpc_ctx->device;
724700
auto sock = get_socket(rpc_ctx->endpoint);
725701
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
726702
RPC_STATUS_ASSERT(status);
727703
} else {
728-
rpc_ctx->gc.add(cgraph);
704+
rpc_ctx->last_graph_uid = cgraph->uid;
729705
std::vector<uint8_t> input;
730706
serialize_graph(rpc_ctx->device, cgraph, input);
731707
auto sock = get_socket(rpc_ctx->endpoint);
@@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
791767
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
792768
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
793769
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
794-
/* .endpoint = */ endpoint,
795-
/* .device = */ device,
796-
/* .name = */ dev_name,
797-
/* .gc = */ {},
770+
/* .endpoint = */ endpoint,
771+
/* .device = */ device,
772+
/* .name = */ dev_name,
773+
/* .last_graph_uid = */ 0,
798774
};
799775
auto reg = ggml_backend_rpc_add_server(endpoint);
800776
ggml_backend_t backend = new ggml_backend {

0 commit comments

Comments
 (0)