diff --git a/c_src/erlzmq_nif.c b/c_src/erlzmq_nif.c index a894f9b..ea69b8d 100644 --- a/c_src/erlzmq_nif.c +++ b/c_src/erlzmq_nif.c @@ -28,6 +28,14 @@ #include #include #include +#include + +#if (defined APPLE || defined __APPLE__) +#include +#include +#else +#include +#endif #define ERLZMQ_MAX_CONCURRENT_REQUESTS 16384 @@ -63,6 +71,7 @@ typedef struct erlzmq_socket { typedef struct { int type; + int64_t deadline; union { struct { erlzmq_socket_t * socket; @@ -544,6 +553,52 @@ NIF(erlzmq_nif_getsockopt) } } +// Can't use enif_monotonic_time(ERL_NIF_MSEC) for time calculation, as it only works in +// scheduler threads. The zclock mono code is derived from the code in the czmq library. + +static int64_t zclock_mono() +{ +#if (defined APPLE || defined __APPLE__) + clock_serv_t cclock; + mach_timespec_t mts; + host_get_clock_service (mach_host_self (), SYSTEM_CLOCK, &cclock); + clock_get_time (cclock, &mts); + mach_port_deallocate (mach_task_self (), cclock); + return (int64_t) ((int64_t) mts.tv_sec * 1000 + (int64_t) mts.tv_nsec / 1000000); + +#elif (defined WIN32 || defined _WIN32 || defined WINDOWS || defined _WINDOWS || defined __WINDOWS__) + // System frequency does not change at run-time, cache it + static int64_t frequency = 0; + if (frequency == 0) { + LARGE_INTEGER freq; + QueryPerformanceFrequency (&freq); + // Windows documentation says that XP and later will always return non-zero + assert (freq.QuadPart != 0); + frequency = freq.QuadPart; + } + LARGE_INTEGER count; + QueryPerformanceCounter (&count); + return (int64_t) (count.QuadPart * 1000) / frequency; +#else + struct timespec ts; + clock_gettime (CLOCK_MONOTONIC, &ts); + return (int64_t) ((int64_t) ts.tv_sec * 1000 + (int64_t) ts.tv_nsec / 1000000); +#endif +} + +static void set_socket_deadline(erlzmq_thread_request_t* req, int option_name) +{ + int timeout_ms = 0; + size_t option_len = sizeof(int); + assert(req->data.recv.socket->socket_zmq); + int rc = zmq_getsockopt(req->data.recv.socket->socket_zmq, option_name, &timeout_ms, &option_len); + if (rc != 0) { + fprintf(stderr, "failed to get socket timeout option: %s\n", zmq_strerror(zmq_errno())); + } + int64_t now = zclock_mono(); + req->deadline = timeout_ms > 0 ? now + timeout_ms : 0; +} + NIF(erlzmq_nif_send) { erlzmq_thread_request_t req; @@ -605,6 +660,7 @@ NIF(erlzmq_nif_send) req.data.send.ref = enif_make_ref(req.data.send.env); enif_self(env, &req.data.send.pid); req.data.send.socket = socket; + set_socket_deadline(&req, ZMQ_SNDTIMEO); zmq_msg_t msg; if (zmq_msg_init_size(&msg, sizeof(erlzmq_thread_request_t))) { @@ -700,6 +756,7 @@ NIF(erlzmq_nif_recv) req.data.recv.ref = enif_make_ref(req.data.recv.env); enif_self(env, &req.data.recv.pid); req.data.recv.socket = socket; + set_socket_deadline(&req, ZMQ_RCVTIMEO); if (zmq_msg_init_size(&msg, sizeof(erlzmq_thread_request_t)) == -1) { enif_free_env(req.data.recv.env); @@ -766,6 +823,7 @@ NIF(erlzmq_nif_close) req.data.close.ref = enif_make_ref(req.data.close.env); enif_self(env, &req.data.close.pid); req.data.close.socket = socket; + req.deadline = 0; zmq_msg_t msg; if (zmq_msg_init_size(&msg, sizeof(erlzmq_thread_request_t))) { @@ -830,6 +888,7 @@ NIF(erlzmq_nif_term) erlzmq_thread_request_t req; req.type = ERLZMQ_THREAD_REQUEST_TERM; + req.deadline = 0; req.data.term.env = enif_alloc_env(); req.data.term.ref = enif_make_ref(req.data.term.env); enif_self(env, &req.data.term.pid); @@ -880,6 +939,53 @@ NIF(erlzmq_nif_version) enif_make_int(env, patch)); } +static int64_t remove_expired_items_and_determine_timeout(vector_t *items_zmq, vector_t *requests) { + int64_t deadline = LLONG_MAX; + int64_t now = zclock_mono(); + for (int i = 1; i < vector_count(items_zmq); ++i) { + zmq_pollitem_t * item = vector_get(zmq_pollitem_t, items_zmq, i); + erlzmq_thread_request_t * r = vector_get(erlzmq_thread_request_t, requests, i); + int64_t now = zclock_mono(); + if (r->deadline && r->deadline <= now) { + int remove_item = 0; + if (r->type == ERLZMQ_THREAD_REQUEST_RECV) { + enif_send(NULL, &r->data.recv.pid, r->data.recv.env, + enif_make_tuple2(r->data.recv.env, + enif_make_copy(r->data.recv.env, r->data.recv.ref), + return_zmq_errno(r->data.recv.env, EAGAIN))); + enif_free_env(r->data.recv.env); + enif_release_resource(r->data.recv.socket); + remove_item = 1; + } else if (r->type == ERLZMQ_THREAD_REQUEST_SEND) { + enif_send(NULL, &r->data.send.pid, r->data.send.env, + enif_make_tuple2(r->data.send.env, + enif_make_copy(r->data.send.env, r->data.send.ref), + return_zmq_errno(r->data.send.env, EAGAIN))); + enif_free_env(r->data.send.env); + enif_release_resource(r->data.send.socket); + remove_item = 1; + } + if (remove_item) { + int status = vector_remove(items_zmq, i); + assert(status == 0); + status = vector_remove(requests, i); + assert(status == 0); + --i; + } + } else if (r->deadline && r->deadline < deadline) { + deadline = r->deadline; + } + } + int64_t timeout; + if (deadline == LLONG_MAX) + timeout = -1; + else { + timeout = deadline - now; + assert(timeout > 0); + } + return timeout; +} + static void * polling_thread(void * handle) { erlzmq_context_t * context = (erlzmq_context_t *) handle; @@ -909,8 +1015,9 @@ static void * polling_thread(void * handle) int i; for (;;) { + int64_t timeout = remove_expired_items_and_determine_timeout(&items_zmq, &requests); int count = zmq_poll(vector_p(zmq_pollitem_t, &items_zmq), - vector_count(&items_zmq), -1); + vector_count(&items_zmq), timeout); assert(count != -1); if (vector_get(zmq_pollitem_t, &items_zmq, 0)->revents & ZMQ_POLLIN) { --count; @@ -1240,6 +1347,7 @@ static ERL_NIF_TERM add_active_req(ErlNifEnv* env, erlzmq_socket_t * socket) req.data.recv.flags = 0; enif_self(env, &req.data.recv.pid); req.data.recv.socket = socket; + set_socket_deadline(&req, ZMQ_RCVTIMEO); zmq_msg_t msg; if (zmq_msg_init_size(&msg, sizeof(erlzmq_thread_request_t))) { diff --git a/src/erlzmq.erl b/src/erlzmq.erl index dd479ff..8cb8f23 100644 --- a/src/erlzmq.erl +++ b/src/erlzmq.erl @@ -186,13 +186,6 @@ send({I, Socket}, Binary, Flags) ok; {Ref, {error, _} = Error} -> Error - after case erlzmq_nif:getsockopt(Socket,?'ZMQ_SNDTIMEO') of - {ok, -1} -> - infinity; - {ok, Else} -> - Else - end -> - {error, eagain} end; Result -> Result @@ -251,13 +244,6 @@ recv({I, Socket}, Flags) Error; {Ref, Result} -> {ok, Result} - after case erlzmq_nif:getsockopt(Socket,?'ZMQ_RCVTIMEO') of - {ok, -1} -> - infinity; - {ok, Else} -> - Else - end -> - {error, eagain} end; Result -> Result diff --git a/test/erlzmq_test.erl b/test/erlzmq_test.erl index 2f9e404..561959c 100644 --- a/test/erlzmq_test.erl +++ b/test/erlzmq_test.erl @@ -492,3 +492,50 @@ basic_tests(Transport, Type1, Type2, Mode) -> ok = erlzmq:close(S2), ok = erlzmq:term(C). +recv_timeout_breaks_message_guarantees_test() -> + ?PRINT_START, + {ok, C} = erlzmq:context(1), + {ok, S1} = erlzmq:socket(C, [push, {active, false}]), + {ok, S2} = erlzmq:socket(C, [pull, {active, false}]), + + ok = erlzmq:setsockopt(S2, sndtimeo, 500), + ok = erlzmq:setsockopt(S2, rcvtimeo, 500), + + ok = erlzmq:connect(S1, "tcp://127.0.0.1:5559"), + ok = erlzmq:bind(S2, "tcp://*:5559"), + + ok = erlzmq:send(S1, <<"ABC">>, [sndmore]), + ok = erlzmq:send(S1, <<"DEF">>), + + %% send and receive a multipart message. should be fine. + {ok, Msg1} = erlzmq:recv(S2), + ?assertEqual(<<"ABC">> , Msg1), + {ok, RcvMore1} = erlzmq:getsockopt(S2, rcvmore), + ?assert(RcvMore1 > 0), + {ok, Msg2} = erlzmq:recv(S2), + {ok, RcvMore2} = erlzmq:getsockopt(S2, rcvmore), + ?assertEqual(0, RcvMore2), + ?assertEqual(<<"DEF">>, Msg2), + + %% try to read when no message is there. this should time out. + + ?assertMatch({error, eagain}, erlzmq:recv(S2)), + + %% %% send the next multipart message. + ok = erlzmq:send(S1, <<"GHI">>, [sndmore]), + ok = erlzmq:send(S1, <<"JKL">>), + + %% now receive all parts. + {ok, Msg3} = erlzmq:recv(S2), + ?assertEqual(<<"GHI">> , Msg3), + {ok, RcvMore3} = erlzmq:getsockopt(S2, rcvmore), + ?assert(RcvMore3 > 0), + {ok, Msg4} = erlzmq:recv(S2), + {ok, RcvMore4} = erlzmq:getsockopt(S2, rcvmore), + ?assertEqual(0, RcvMore4), + ?assertEqual(<<"JKL">>, Msg4), + + ok = erlzmq:close(S1), + ok = erlzmq:close(S2), + ok = erlzmq:term(C), + ?PRINT_END.