Skip to content

Commit 664f9cd

Browse files
committed
refactor: unify sync and async callable invocation
1 parent 631b337 commit 664f9cd

File tree

6 files changed

+47
-89
lines changed

6 files changed

+47
-89
lines changed

pyrogram/dispatcher.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import asyncio
2020
from collections import OrderedDict
21-
import inspect
2221
import logging
2322
from typing import Dict
2423

@@ -262,12 +261,7 @@ async def deleted_business_messages_parser(update, users, chats):
262261
async def start(self):
263262
if callable(self.client.start_handler):
264263
try:
265-
if inspect.iscoroutinefunction(self.client.start_handler):
266-
await self.client.start_handler(self.client)
267-
else:
268-
result = self.client.start_handler(self.client)
269-
if inspect.isawaitable(result):
270-
await result
264+
await utils.invoke_callable(self.client.start_handler, self.client)
271265
except Exception as e:
272266
log.exception("start_handler raised: %s", e)
273267

@@ -325,12 +319,7 @@ async def stop(self, clear_handlers: bool = True):
325319

326320
if callable(self.client.stop_handler):
327321
try:
328-
if inspect.iscoroutinefunction(self.client.stop_handler):
329-
await self.client.stop_handler(self.client)
330-
else:
331-
result = self.client.stop_handler(self.client)
332-
if inspect.isawaitable(result):
333-
await result
322+
await utils.invoke_callable(self.client.stop_handler, self.client)
334323
except Exception as e:
335324
log.exception("stop_handler raised: %s", e)
336325

@@ -437,15 +426,10 @@ async def handler_worker(self, lock):
437426
continue
438427

439428
try:
440-
if inspect.iscoroutinefunction(handler.callback):
441-
await handler.callback(self.client, *args)
442-
else:
443-
await self.client.loop.run_in_executor(
444-
self.client.executor,
445-
handler.callback,
446-
self.client,
447-
*args
448-
)
429+
await utils.invoke_callable(
430+
handler.callback, self.client, *args,
431+
executor=self.client.executor, loop=self.client.loop
432+
)
449433
except asyncio.CancelledError:
450434
raise
451435
except pyrogram.StopPropagation:
@@ -482,15 +466,11 @@ async def handle_update_handler_exception(
482466
continue
483467

484468
try:
485-
if inspect.iscoroutinefunction(handler.callback):
486-
await handler.callback(
487-
self.client, exc, update_handler, update, users, chats
488-
)
489-
else:
490-
await self.client.loop.run_in_executor(
491-
self.client.executor, handler.callback,
492-
self.client, exc, update_handler, update, users, chats
493-
)
469+
await utils.invoke_callable(
470+
handler.callback,
471+
self.client, exc, update_handler, update, users, chats,
472+
executor=self.client.executor, loop=self.client.loop
473+
)
494474
except pyrogram.StopPropagation:
495475
handled = True
496476
raise

pyrogram/handlers/callback_query_handler.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import asyncio
20-
from asyncio import iscoroutinefunction
2120
from typing import Callable, Tuple
2221

2322
import pyrogram
2423

24+
from pyrogram import utils
2525
from pyrogram.utils import PyromodConfig
2626
from pyrogram.types import CallbackQuery, Identifier, Listener
2727

@@ -108,12 +108,10 @@ async def check_if_has_matching_listener(
108108
if listener:
109109
filters = listener.filters
110110
if callable(filters):
111-
if iscoroutinefunction(filters.__call__):
112-
listener_does_match = await filters(client, query)
113-
else:
114-
listener_does_match = await client.loop.run_in_executor(
115-
client.executor, filters, client, query
116-
)
111+
listener_does_match = await utils.invoke_callable(
112+
filters, client, query,
113+
executor=client.executor, loop=client.loop
114+
)
117115
else:
118116
listener_does_match = True
119117

@@ -135,12 +133,10 @@ async def check(self, client: "pyrogram.Client", query: CallbackQuery):
135133
query._matched_listener = listener if listener_does_match else None
136134

137135
if callable(self.filters):
138-
if iscoroutinefunction(self.filters.__call__):
139-
handler_does_match = await self.filters(client, query)
140-
else:
141-
handler_does_match = await client.loop.run_in_executor(
142-
client.executor, self.filters, client, query
143-
)
136+
handler_does_match = await utils.invoke_callable(
137+
self.filters, client, query,
138+
executor=client.executor, loop=client.loop
139+
)
144140
else:
145141
handler_does_match = True
146142

@@ -198,12 +194,8 @@ async def resolve_future_or_callback(
198194
raise pyrogram.StopPropagation
199195
elif listener.callback:
200196
try:
201-
if iscoroutinefunction(listener.callback):
202-
await listener.callback(client, query, *args)
203-
else:
204-
listener.callback(client, query, *args)
197+
await utils.invoke_callable(listener.callback, client, query, *args)
205198
except asyncio.CancelledError:
206-
# Cancelled during shutdown/interruption
207199
raise pyrogram.StopPropagation
208200

209201
raise pyrogram.StopPropagation

pyrogram/handlers/conversation_handler.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
import inspect
2019
from typing import Union
2120

2221
import pyrogram
22+
from pyrogram import utils
2323
from pyrogram.types import Message, CallbackQuery
2424
from .message_handler import MessageHandler
2525
from .callback_query_handler import CallbackQueryHandler
@@ -57,21 +57,11 @@ async def check(self, client: "pyrogram.Client", update: Union[Message, Callback
5757

5858
filters = waiter.get('filters')
5959
if callable(filters):
60-
is_async = (
61-
inspect.iscoroutinefunction(filters)
62-
or inspect.iscoroutinefunction(getattr(filters, "__call__", None))
60+
filtered = await utils.invoke_callable(
61+
filters, client, update,
62+
executor=client.executor, loop=client.loop
6363
)
6464

65-
if is_async:
66-
result = filters(client, update)
67-
filtered = await result if inspect.isawaitable(result) else result
68-
else:
69-
filtered = await client.loop.run_in_executor(
70-
client.executor,
71-
filters,
72-
client, update
73-
)
74-
7565
if not filtered or waiter['future'].done():
7666
return False
7767

pyrogram/handlers/handler.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
import inspect
2019
import logging
2120
from typing import Callable, Optional
2221

2322
import pyrogram
23+
from pyrogram import utils
2424
from pyrogram.filters import Filter
2525
from pyrogram.types import Update
2626

@@ -34,14 +34,10 @@ def __init__(self, callback: Callable, filters: Optional[Filter] = None):
3434

3535
async def check(self, client: "pyrogram.Client", update: Update):
3636
if callable(self.filters):
37-
if inspect.iscoroutinefunction(self.filters.__call__):
38-
return await self.filters(client, update)
39-
else:
40-
return await client.loop.run_in_executor(
41-
client.executor,
42-
self.filters,
43-
client, update
44-
)
37+
return await utils.invoke_callable(
38+
self.filters, client, update,
39+
executor=client.executor, loop=client.loop
40+
)
4541

4642
if self.filters is not None:
4743
log.warning("Non-callable filter %r treated as match-all", self.filters)

pyrogram/handlers/message_handler.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import asyncio
20-
from inspect import iscoroutinefunction
2120
from typing import Callable
2221
import pyrogram
2322

23+
from pyrogram import utils
2424
from pyrogram.types import Message, Identifier
2525

2626
from .handler import Handler
@@ -85,12 +85,10 @@ async def check_if_has_matching_listener(self, client: "pyrogram.Client", messag
8585
if listener:
8686
filters = listener.filters
8787
if callable(filters):
88-
if iscoroutinefunction(filters.__call__):
89-
listener_does_match = await filters(client, message)
90-
else:
91-
listener_does_match = await client.loop.run_in_executor(
92-
client.executor, filters, client, message
93-
)
88+
listener_does_match = await utils.invoke_callable(
89+
filters, client, message,
90+
executor=client.executor, loop=client.loop
91+
)
9492
else:
9593
listener_does_match = True
9694

@@ -109,12 +107,10 @@ async def check(self, client: "pyrogram.Client", message: Message):
109107
message._matched_listener = listener if listener_does_match else None
110108

111109
if callable(self.filters):
112-
if iscoroutinefunction(self.filters.__call__):
113-
handler_does_match = await self.filters(client, message)
114-
else:
115-
handler_does_match = await client.loop.run_in_executor(
116-
client.executor, self.filters, client, message
117-
)
110+
handler_does_match = await utils.invoke_callable(
111+
self.filters, client, message,
112+
executor=client.executor, loop=client.loop
113+
)
118114
else:
119115
handler_does_match = True
120116

@@ -146,10 +142,7 @@ async def resolve_future_or_callback(self, client: "pyrogram.Client", message: M
146142

147143
raise pyrogram.StopPropagation
148144
elif listener.callback:
149-
if iscoroutinefunction(listener.callback):
150-
await listener.callback(client, message, *args)
151-
else:
152-
listener.callback(client, message, *args)
145+
await utils.invoke_callable(listener.callback, client, message, *args)
153146

154147
raise pyrogram.StopPropagation
155148
else:

pyrogram/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import asyncio
2020
import base64
21+
import inspect
2122
import logging
2223
import functools
2324
import hashlib
@@ -39,6 +40,12 @@
3940
log = logging.getLogger(__name__)
4041

4142

43+
async def invoke_callable(func, *args, executor=None, loop=None):
44+
if inspect.iscoroutinefunction(func) or inspect.iscoroutinefunction(getattr(func, "__call__", None)):
45+
return await func(*args)
46+
return await (loop or asyncio.get_running_loop()).run_in_executor(executor, func, *args)
47+
48+
4249
def get_event_loop() -> asyncio.AbstractEventLoop:
4350
try:
4451
return asyncio.get_running_loop()

0 commit comments

Comments
 (0)