# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0
"""
This library allows tracing PostgreSQL queries made by the
`asyncpg <https://magicstack.github.io/asyncpg/current/>`_ library.
Usage
-----
Start PostgreSQL:
::
docker run -e POSTGRES_USER=user -e POSTGRES_PASSWORD=password -e POSTGRES_DATABASE=database -p 5432:5432 postgres
Run instrumented code:
.. code-block:: python
import asyncio
import asyncpg
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
# You can optionally pass a custom TracerProvider to AsyncPGInstrumentor.instrument()
AsyncPGInstrumentor().instrument()
async def main():
conn = await asyncpg.connect(user='user', password='password')
await conn.fetch('''SELECT 42;''')
await conn.close()
asyncio.run(main())
API
---
"""
import re
from typing import Collection
import asyncpg
import wrapt
from opentelemetry import trace
from opentelemetry.instrumentation.asyncpg.package import _instruments
from opentelemetry.instrumentation.asyncpg.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv._incubating.attributes.db_attributes import (
DB_NAME,
DB_STATEMENT,
DB_SYSTEM,
DB_USER,
DbSystemValues,
)
from opentelemetry.semconv._incubating.attributes.net_attributes import (
NET_PEER_NAME,
NET_PEER_PORT,
NET_TRANSPORT,
NetTransportValues,
)
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import Status, StatusCode
def _hydrate_span_from_args(connection, query, parameters) -> dict:
"""Get network and database attributes from connection."""
span_attributes = {DB_SYSTEM: DbSystemValues.POSTGRESQL.value}
# connection contains _params attribute which is a namedtuple ConnectionParameters.
# https://github.com/MagicStack/asyncpg/blob/master/asyncpg/connection.py#L68
params = getattr(connection, "_params", None)
dbname = getattr(params, "database", None)
if dbname:
span_attributes[DB_NAME] = dbname
user = getattr(params, "user", None)
if user:
span_attributes[DB_USER] = user
# connection contains _addr attribute which is either a host/port tuple, or unix socket string
# https://magicstack.github.io/asyncpg/current/_modules/asyncpg/connection.html
addr = getattr(connection, "_addr", None)
if isinstance(addr, tuple):
span_attributes[NET_PEER_NAME] = addr[0]
span_attributes[NET_PEER_PORT] = addr[1]
span_attributes[NET_TRANSPORT] = NetTransportValues.IP_TCP.value
elif isinstance(addr, str):
span_attributes[NET_PEER_NAME] = addr
span_attributes[NET_TRANSPORT] = NetTransportValues.OTHER.value
if query is not None:
span_attributes[DB_STATEMENT] = query
if parameters is not None and len(parameters) > 0:
span_attributes["db.statement.parameters"] = str(parameters)
return span_attributes
[docs]class AsyncPGInstrumentor(BaseInstrumentor):
_leading_comment_remover = re.compile(r"^/\*.*?\*/")
_tracer = None
def __init__(self, capture_parameters=False):
super().__init__()
self.capture_parameters = capture_parameters
[docs] def instrumentation_dependencies(self) -> Collection[str]:
return _instruments
def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
self._tracer = trace.get_tracer(
__name__,
__version__,
tracer_provider,
schema_url="https://opentelemetry.io/schemas/1.11.0",
)
for method in [
"Connection.execute",
"Connection.executemany",
"Connection.fetch",
"Connection.fetchval",
"Connection.fetchrow",
]:
wrapt.wrap_function_wrapper(
"asyncpg.connection", method, self._do_execute
)
for method in [
"Cursor.fetch",
"Cursor.forward",
"Cursor.fetchrow",
"CursorIterator.__anext__",
]:
wrapt.wrap_function_wrapper(
"asyncpg.cursor", method, self._do_cursor_execute
)
def _uninstrument(self, **__):
for cls, methods in [
(
asyncpg.connection.Connection,
("execute", "executemany", "fetch", "fetchval", "fetchrow"),
),
(asyncpg.cursor.Cursor, ("forward", "fetch", "fetchrow")),
(asyncpg.cursor.CursorIterator, ("__anext__",)),
]:
for method_name in methods:
unwrap(cls, method_name)
async def _do_execute(self, func, instance, args, kwargs):
exception = None
params = getattr(instance, "_params", None)
name = (
args[0] if args[0] else getattr(params, "database", "postgresql")
)
try:
# Strip leading comments so we get the operation name.
name = self._leading_comment_remover.sub("", name).split()[0]
except IndexError:
name = ""
# Hydrate attributes before span creation to enable filtering
span_attributes = _hydrate_span_from_args(
instance,
args[0],
args[1:] if self.capture_parameters else None,
)
with self._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
try:
result = await func(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
raise
finally:
if span.is_recording() and exception is not None:
span.set_status(Status(StatusCode.ERROR))
return result
async def _do_cursor_execute(self, func, instance, args, kwargs):
"""Wrap cursor based functions. For every call this will generate a new span."""
exception = None
params = getattr(instance._connection, "_params", None)
name = (
instance._query
if instance._query
else getattr(params, "database", "postgresql")
)
try:
# Strip leading comments so we get the operation name.
name = self._leading_comment_remover.sub("", name).split()[0]
except IndexError:
name = ""
# Hydrate attributes before span creation to enable filtering
span_attributes = _hydrate_span_from_args(
instance._connection,
instance._query,
instance._args if self.capture_parameters else None,
)
stop = False
with self._tracer.start_as_current_span(
f"CURSOR: {name}",
kind=SpanKind.CLIENT,
attributes=span_attributes,
) as span:
try:
result = await func(*args, **kwargs)
except StopAsyncIteration:
# Do not show this exception to the span
stop = True
except Exception as exc: # pylint: disable=W0703
exception = exc
raise
finally:
if span.is_recording() and exception is not None:
span.set_status(Status(StatusCode.ERROR))
if not stop:
return result
raise StopAsyncIteration