# Source code for opentelemetry.instrumentation.redis

# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Instrument `redis`_ to report Redis queries.

There are two options for instrumenting code. The first option is to use the
``opentelemetry-instrument`` executable which will automatically
instrument your Redis client. The second is to programmatically enable
instrumentation via the following code:

.. _redis: https://pypi.org/project/redis/

Usage
-----

.. code:: python

    from opentelemetry.instrumentation.redis import RedisInstrumentor
    import redis


    # Instrument redis
    RedisInstrumentor().instrument()

    # This will report a span with the default settings
    client = redis.StrictRedis(host="localhost", port=6379)
    client.get("my-key")

Async Redis clients (e.g. ``redis.asyncio.Redis``) are also instrumented in the same way:

.. code:: python

    from opentelemetry.instrumentation.redis import RedisInstrumentor
    import redis.asyncio


    # Instrument redis
    RedisInstrumentor().instrument()

    # This will report a span with the default settings
    async def redis_get():
        client = redis.asyncio.Redis(host="localhost", port=6379)
        await client.get("my-key")

The `instrument` method accepts the following keyword args:

tracer_provider (TracerProvider) - an optional tracer provider

request_hook (Callable) - a function with extra user-defined logic to run before a single command is sent
(it is not called for pipeline ``execute``). Its signature is: def request_hook(span: Span, instance, args, kwargs) -> None

response_hook (Callable) - a function with extra user-defined logic to run after the wrapped call returns,
for both single commands and pipelines. Its signature is: def response_hook(span: Span, instance, response) -> None

In both hooks, ``instance`` is the redis client, pipeline, or cluster object whose method was called.

For example:

.. code:: python

    from opentelemetry.instrumentation.redis import RedisInstrumentor
    import redis

    def request_hook(span, instance, args, kwargs):
        if span and span.is_recording():
            span.set_attribute("custom_user_attribute_from_request_hook", "some-value")

    def response_hook(span, instance, response):
        if span and span.is_recording():
            span.set_attribute("custom_user_attribute_from_response_hook", "some-value")

    # Instrument redis with hooks
    RedisInstrumentor().instrument(request_hook=request_hook, response_hook=response_hook)

    # This will report a span with the default settings and the custom attributes added from the hooks
    client = redis.StrictRedis(host="localhost", port=6379)
    client.get("my-key")


API
---
"""
import typing
from typing import Any, Collection

import redis
from wrapt import wrap_function_wrapper

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.redis.package import _instruments
from opentelemetry.instrumentation.redis.util import (
    _extract_conn_attributes,
    _format_command_args,
)
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

# NOTE(review): not referenced by the code shown in this module; confirm
# whether it is still needed.
_DEFAULT_SERVICE = "redis"

# Type of the optional ``request_hook``: called with
# (span, instance, command args, command kwargs) and returns nothing.
# NOTE(review): the wrappers installed by ``_instrument`` pass the wrapped
# client/pipeline object as ``instance`` (e.g. ``redis.Redis``), not a
# ``redis.connection.Connection``; this annotation looks inaccurate.
_RequestHookT = typing.Optional[
    typing.Callable[
        [Span, redis.connection.Connection, typing.List, typing.Dict], None
    ]
]
# Type of the optional ``response_hook``: called with
# (span, instance, response) and returns nothing. The same NOTE(review)
# about ``instance`` applies.
_ResponseHookT = typing.Optional[
    typing.Callable[[Span, redis.connection.Connection, Any], None]
]

# ``redis.asyncio`` is only imported (and later wrapped/unwrapped) when the
# installed redis-py is at least this version.
_REDIS_ASYNCIO_VERSION = (4, 2, 0)
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
    import redis.asyncio

# Minimum redis-py versions from which the sync cluster classes
# (``redis.cluster``) and the asyncio cluster classes
# (``redis.asyncio.cluster``) are wrapped.
_REDIS_CLUSTER_VERSION = (4, 1, 0)
_REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 2)


def _set_connection_attributes(span, conn):
    if not span.is_recording() or not hasattr(conn, "connection_pool"):
        return
    for key, value in _extract_conn_attributes(
        conn.connection_pool.connection_kwargs
    ).items():
        span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
    if len(cmd_args) > 0 and cmd_args[0]:
        name = cmd_args[0]
    else:
        name = instance.connection_pool.connection_kwargs.get("db", 0)
    return name


def _build_span_meta_data_for_pipeline(instance):
    try:
        command_stack = (
            instance.command_stack
            if hasattr(instance, "command_stack")
            else instance._command_stack
        )

        cmds = [
            _format_command_args(c.args if hasattr(c, "args") else c[0])
            for c in command_stack
        ]
        resource = "\n".join(cmds)

        span_name = " ".join(
            [
                (c.args[0] if hasattr(c, "args") else c[0][0])
                for c in command_stack
            ]
        )
    except (AttributeError, IndexError):
        command_stack = []
        resource = ""
        span_name = ""

    return command_stack, resource, span_name


# pylint: disable=R0915
def _instrument(
    tracer,
    request_hook: _RequestHookT = None,
    response_hook: _ResponseHookT = None,
):
    """Install tracing wrappers on the redis client entry points.

    Each wrapped call runs inside a CLIENT-kind span started from ``tracer``.
    The sync client and pipeline methods are always wrapped. The cluster,
    asyncio and asyncio-cluster variants are wrapped only when
    ``redis.VERSION`` meets the corresponding minimum version.

    Args:
        tracer: tracer used to start the spans.
        request_hook: optional callable, invoked as
            ``request_hook(span, instance, args, kwargs)`` just before a
            single command is executed. It is NOT invoked for pipeline
            ``execute`` calls.
        response_hook: optional callable, invoked as
            ``response_hook(span, instance, response)`` after the wrapped
            call returns, for both single commands and pipelines. It is
            skipped when the wrapped call raises.

    In the hooks, ``instance`` is the object whose wrapped method was called
    (a client, pipeline, or cluster object).
    """

    def _traced_execute_command(func, instance, args, kwargs):
        # Wrapper for single commands. args[0] (the command name, when
        # truthy) becomes the span name; see _build_span_name.
        query = _format_command_args(args)
        name = _build_span_name(instance, args)

        with tracer.start_as_current_span(
            name, kind=trace.SpanKind.CLIENT
        ) as span:
            # Attributes are only set on recording spans.
            if span.is_recording():
                span.set_attribute(SpanAttributes.DB_STATEMENT, query)
                _set_connection_attributes(span, instance)
                span.set_attribute("db.redis.args_length", len(args))
            if callable(request_hook):
                request_hook(span, instance, args, kwargs)
            # If func raises, the exception leaves the ``with`` block and
            # response_hook is not called.
            response = func(*args, **kwargs)
            if callable(response_hook):
                response_hook(span, instance, response)
            return response

    def _traced_execute_pipeline(func, instance, args, kwargs):
        # Wrapper for pipeline execution. The span describes the queued
        # command stack as read before ``func`` runs.
        (
            command_stack,
            resource,
            span_name,
        ) = _build_span_meta_data_for_pipeline(instance)

        with tracer.start_as_current_span(
            span_name, kind=trace.SpanKind.CLIENT
        ) as span:
            if span.is_recording():
                span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
                _set_connection_attributes(span, instance)
                span.set_attribute(
                    "db.redis.pipeline_length", len(command_stack)
                )
            # NOTE: request_hook is intentionally not called for pipelines.
            response = func(*args, **kwargs)
            if callable(response_hook):
                response_hook(span, instance, response)
            return response

    # This module targets StrictRedis/BasePipeline on redis-py < 3.0.0 and
    # Redis/Pipeline otherwise.
    pipeline_class = (
        "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
    )
    redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"

    # Sync client: single commands, pipeline execute, and the pipeline's
    # immediate_execute_command (traced like a single command).
    wrap_function_wrapper(
        "redis", f"{redis_class}.execute_command", _traced_execute_command
    )
    wrap_function_wrapper(
        "redis.client",
        f"{pipeline_class}.execute",
        _traced_execute_pipeline,
    )
    wrap_function_wrapper(
        "redis.client",
        f"{pipeline_class}.immediate_execute_command",
        _traced_execute_command,
    )
    if redis.VERSION >= _REDIS_CLUSTER_VERSION:
        wrap_function_wrapper(
            "redis.cluster",
            "RedisCluster.execute_command",
            _traced_execute_command,
        )
        wrap_function_wrapper(
            "redis.cluster",
            "ClusterPipeline.execute",
            _traced_execute_pipeline,
        )

    async def _async_traced_execute_command(func, instance, args, kwargs):
        # Async counterpart of _traced_execute_command: same span handling,
        # but ``func`` is awaited.
        query = _format_command_args(args)
        name = _build_span_name(instance, args)

        with tracer.start_as_current_span(
            name, kind=trace.SpanKind.CLIENT
        ) as span:
            if span.is_recording():
                span.set_attribute(SpanAttributes.DB_STATEMENT, query)
                _set_connection_attributes(span, instance)
                span.set_attribute("db.redis.args_length", len(args))
            if callable(request_hook):
                request_hook(span, instance, args, kwargs)
            response = await func(*args, **kwargs)
            if callable(response_hook):
                response_hook(span, instance, response)
            return response

    async def _async_traced_execute_pipeline(func, instance, args, kwargs):
        # Async counterpart of _traced_execute_pipeline (no request_hook).
        (
            command_stack,
            resource,
            span_name,
        ) = _build_span_meta_data_for_pipeline(instance)

        with tracer.start_as_current_span(
            span_name, kind=trace.SpanKind.CLIENT
        ) as span:
            if span.is_recording():
                span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
                _set_connection_attributes(span, instance)
                span.set_attribute(
                    "db.redis.pipeline_length", len(command_stack)
                )
            response = await func(*args, **kwargs)
            if callable(response_hook):
                response_hook(span, instance, response)
            return response

    # This branch requires redis >= 4.2.0, so redis_class/pipeline_class
    # resolve to "Redis"/"Pipeline" here.
    if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
        wrap_function_wrapper(
            "redis.asyncio",
            f"{redis_class}.execute_command",
            _async_traced_execute_command,
        )
        wrap_function_wrapper(
            "redis.asyncio.client",
            f"{pipeline_class}.execute",
            _async_traced_execute_pipeline,
        )
        wrap_function_wrapper(
            "redis.asyncio.client",
            f"{pipeline_class}.immediate_execute_command",
            _async_traced_execute_command,
        )
    if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
        wrap_function_wrapper(
            "redis.asyncio.cluster",
            "RedisCluster.execute_command",
            _async_traced_execute_command,
        )
        wrap_function_wrapper(
            "redis.asyncio.cluster",
            "ClusterPipeline.execute",
            _async_traced_execute_pipeline,
        )


class RedisInstrumentor(BaseInstrumentor):
    """An instrumentor for Redis.

    See `BaseInstrumentor`.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return ``_instruments``, the dependency specifiers declared in
        ``opentelemetry.instrumentation.redis.package``."""
        return _instruments

    def _instrument(self, **kwargs):
        """Instruments the redis module.

        Args:
            **kwargs: Optional arguments
                ``tracer_provider``: a TracerProvider, defaults to global.
                ``request_hook``: An optional callback invoked as
                    ``request_hook(span, instance, args, kwargs)`` right
                    before a single command is executed (not for pipelines).
                ``response_hook``: An optional callback which is invoked
                    right before the span is finished processing a response.
        """
        tracer_provider = kwargs.get("tracer_provider")
        tracer = trace.get_tracer(
            __name__, __version__, tracer_provider=tracer_provider
        )
        # Resolves to the module-level ``_instrument`` function: class scope
        # does not take part in name lookup inside methods.
        _instrument(
            tracer,
            request_hook=kwargs.get("request_hook"),
            response_hook=kwargs.get("response_hook"),
        )

    def _uninstrument(self, **kwargs):
        """Remove the wrappers installed by the module-level ``_instrument``.

        The methods unwrapped here mirror exactly the methods wrapped for the
        installed redis version. ``pipeline`` is never wrapped, so it is not
        unwrapped either. Unwrapping it would at best be a no-op, and could
        otherwise remove a wrapper this instrumentation did not install.
        """
        if redis.VERSION < (3, 0, 0):
            unwrap(redis.StrictRedis, "execute_command")
            unwrap(
                redis.client.BasePipeline,  # pylint:disable=no-member
                "execute",
            )
            unwrap(
                redis.client.BasePipeline,  # pylint:disable=no-member
                "immediate_execute_command",
            )
        else:
            unwrap(redis.Redis, "execute_command")
            unwrap(redis.client.Pipeline, "execute")
            unwrap(redis.client.Pipeline, "immediate_execute_command")
        if redis.VERSION >= _REDIS_CLUSTER_VERSION:
            unwrap(redis.cluster.RedisCluster, "execute_command")
            unwrap(redis.cluster.ClusterPipeline, "execute")
        if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
            unwrap(redis.asyncio.Redis, "execute_command")
            unwrap(redis.asyncio.client.Pipeline, "execute")
            unwrap(redis.asyncio.client.Pipeline, "immediate_execute_command")
        if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
            unwrap(redis.asyncio.cluster.RedisCluster, "execute_command")
            unwrap(redis.asyncio.cluster.ClusterPipeline, "execute")