# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0
"""
Instrument `celery`_ to trace Celery applications.

.. _celery: https://pypi.org/project/celery/

Usage
-----

* Start the broker backend

.. code::

    docker run -p 5672:5672 rabbitmq

* Run an instrumented task

.. code:: python

    from opentelemetry.instrumentation.celery import CeleryInstrumentor

    from celery import Celery
    from celery.signals import worker_process_init

    @worker_process_init.connect(weak=False)
    def init_celery_tracing(*args, **kwargs):
        CeleryInstrumentor().instrument()

    app = Celery("tasks", broker="amqp://localhost")

    @app.task
    def add(x, y):
        return x + y

    add.delay(42, 50)

Setting up tracing
------------------

When tracing a Celery worker process, both tracing and instrumentation must be initialized after the Celery worker
process has been initialized. This is required for tracing components that use threads, such as the
BatchSpanProcessor, to work correctly. Celery provides a ``worker_process_init`` signal that can be used for this,
as shown in the example above.

API
---
"""
from __future__ import annotations
import logging
from collections.abc import Collection, Iterable
from timeit import default_timer
from billiard import VERSION
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module
from celery.worker.request import Request # pylint: disable=no-name-in-module
from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.package import _instruments
from opentelemetry.instrumentation.celery.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.metrics import get_meter
from opentelemetry.propagate import extract, inject
from opentelemetry.propagators.textmap import Getter
from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
MESSAGING_MESSAGE_ID,
)
from opentelemetry.trace.status import Status, StatusCode
# ``ExceptionWithTraceback`` is only imported from billiard 4.0.1 onwards;
# on older releases the name is bound to ``None`` so code can test for it.
if VERSION < (4, 0, 1):
    ExceptionWithTraceback = None
else:
    from billiard.einfo import ExceptionWithTraceback

logger = logging.getLogger(__name__)

# Span attribute recording which Celery operation a span represents,
# and the two values it takes.
_TASK_TAG_KEY = "celery.action"
_TASK_APPLY_ASYNC = "apply_async"
_TASK_RUN = "run"

# Span attribute keys for retry / revocation metadata.
_TASK_RETRY_REASON_KEY = "celery.retry.reason"
_TASK_REVOKED_REASON_KEY = "celery.revoked.reason"
_TASK_REVOKED_TERMINATED_SIGNAL_KEY = "celery.terminated.signal"

# Span attribute key holding the task name.
_TASK_NAME_KEY = "celery.task_name"
class CeleryGetter(Getter[Request]):
    """Read propagation fields from a Celery task request.

    Celery's Context copies message properties onto the request as instance
    attributes, so each propagation key is looked up with ``getattr``.
    """

    def get(self, carrier: Request, key: str) -> list[str] | None:
        """Return the value stored under ``key`` on ``carrier`` as strings.

        Returns ``None`` when the attribute is absent or ``None``. Otherwise
        the result is a list of strings:

        * a ``str`` is returned as a one-element list;
        * ``bytes``/``bytearray`` are decoded as UTF-8 (invalid sequences
          replaced) into a one-element list instead of being split into
          per-byte integers;
        * any other iterable is converted element by element;
        * any other value is stringified into a one-element list.
        """
        value = getattr(carrier, key, None)
        if value is None:
            return None
        # Celery's Context copies all message properties as instance
        # attributes, including non-string values like timelimit (tuple
        # of ints). The TextMapPropagator contract requires string
        # values, so coerce anything that isn't already a string.
        if isinstance(value, str):
            return [value]
        if isinstance(value, (bytes, bytearray)):
            # bytes are Iterable[int]; iterating them would produce
            # ["97", "98", ...] instead of the textual value.
            return [value.decode("utf-8", errors="replace")]
        if isinstance(value, Iterable):
            return [v if isinstance(v, str) else str(v) for v in value]
        return [str(value)]

    def keys(self, carrier: Request) -> list[str]:
        """Return an empty list; request attributes are not enumerated."""
        return []
celery_getter = CeleryGetter()
class CeleryInstrumentor(BaseInstrumentor):
    """Trace and measure Celery task publishing and execution.

    ``instrument()`` connects handlers to Celery signals:

    * ``before_task_publish`` / ``after_task_publish`` wrap a publish in a
      PRODUCER span and inject the current trace context into the message
      headers.
    * ``task_prerun`` / ``task_postrun`` wrap task execution in a CONSUMER
      span whose parent context is extracted from the task request, and
      record the run time in the ``flower.task.runtime.seconds`` histogram.
    * ``task_failure`` / ``task_retry`` annotate the running CONSUMER span.
    """

    def __init__(self):
        super().__init__()
        # Only create state that is not already present, so re-running
        # __init__ on an existing instance keeps its state.
        # NOTE(review): presumably BaseInstrumentor can hand back a shared
        # instance; confirm before relying on this.
        if not hasattr(self, "metrics"):
            self.metrics = None
        if not hasattr(self, "task_id_to_start_time"):
            self.task_id_to_start_time = {}

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return ``_instruments`` from this package's ``package`` module."""
        return _instruments

    def _instrument(self, **kwargs):
        """Create the tracer and metrics, then connect the signal handlers.

        Optional keyword arguments: ``tracer_provider`` and ``meter_provider``.
        """
        tracer_provider = kwargs.get("tracer_provider")
        self._tracer = trace.get_tracer(
            __name__,
            __version__,
            tracer_provider,
            schema_url="https://opentelemetry.io/schemas/1.11.0",
        )
        meter_provider = kwargs.get("meter_provider")
        meter = get_meter(
            __name__,
            __version__,
            meter_provider,
            schema_url="https://opentelemetry.io/schemas/1.11.0",
        )
        self.task_id_to_start_time = {}
        self.create_celery_metrics(meter)
        # weak=False: the receivers are bound methods, which would otherwise
        # only be weakly referenced by the signal.
        signals.task_prerun.connect(self._trace_prerun, weak=False)
        signals.task_postrun.connect(self._trace_postrun, weak=False)
        signals.before_task_publish.connect(
            self._trace_before_publish, weak=False
        )
        signals.after_task_publish.connect(
            self._trace_after_publish, weak=False
        )
        signals.task_failure.connect(self._trace_failure, weak=False)
        signals.task_retry.connect(self._trace_retry, weak=False)

    def _uninstrument(self, **kwargs):
        """Disconnect every handler connected by ``_instrument``."""
        signals.task_prerun.disconnect(self._trace_prerun)
        signals.task_postrun.disconnect(self._trace_postrun)
        signals.before_task_publish.disconnect(self._trace_before_publish)
        signals.after_task_publish.disconnect(self._trace_after_publish)
        signals.task_failure.disconnect(self._trace_failure)
        signals.task_retry.disconnect(self._trace_retry)
        self.task_id_to_start_time = {}

    def _trace_prerun(self, *args, **kwargs):
        """Start the CONSUMER span for a task that is about to run.

        The span's parent is extracted from ``task.request``. The span is made
        current and stored via ``utils.attach_context`` so that
        ``_trace_postrun`` can finish it.
        """
        task = utils.retrieve_task(kwargs)
        task_id = utils.retrieve_task_id(kwargs)
        if task is None or task_id is None:
            return
        # Store the start time; _trace_postrun converts it to a duration.
        self.update_task_duration_time(task_id)
        request = task.request
        # A falsy (e.g. empty) extracted context is treated as "no parent".
        tracectx = extract(request, getter=celery_getter) or None
        token = context_api.attach(tracectx) if tracectx is not None else None
        logger.debug("prerun signal start task_id=%s", task_id)
        operation_name = f"{_TASK_RUN}/{task.name}"
        span = self._tracer.start_span(
            operation_name, context=tracectx, kind=trace.SpanKind.CONSUMER
        )
        activation = trace.use_span(span, end_on_exit=True)
        activation.__enter__()  # pylint: disable=unnecessary-dunder-call
        utils.attach_context(task, task_id, span, activation, token)

    def _trace_postrun(self, *args, **kwargs):
        """Finish the CONSUMER span started by ``_trace_prerun``.

        Adds request attributes, ends the span, records the task run time and
        detaches the context token attached in prerun (if any).
        """
        task = utils.retrieve_task(kwargs)
        task_id = utils.retrieve_task_id(kwargs)
        if task is None or task_id is None:
            return
        logger.debug("postrun signal task_id=%s", task_id)
        # retrieve and finish the Span
        ctx = utils.retrieve_context(task, task_id)
        if ctx is None:
            logger.warning("no existing span found for task_id=%s", task_id)
            # Drop any start time recorded for this task so it cannot linger
            # in task_id_to_start_time.
            self.task_id_to_start_time.pop(task_id, None)
            return
        span, activation, token = ctx
        # request context tags
        if span.is_recording():
            span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
            utils.set_attributes_from_context(span, kwargs)
            utils.set_attributes_from_context(span, task.request)
            span.set_attribute(_TASK_NAME_KEY, task.name)
        activation.__exit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
        utils.detach_context(task, task_id)
        # Convert the stored start time into the elapsed duration.
        self.update_task_duration_time(task_id)
        labels = {"task": task.name, "worker": task.request.hostname}
        self._record_histograms(task_id, labels)
        self.task_id_to_start_time.pop(task_id, None)
        # if the process sending the task is not instrumented
        # there's no incoming context and no token to detach
        if token is not None:
            context_api.detach(token)

    def _trace_before_publish(self, *args, **kwargs):
        """Start a PRODUCER span and inject its context into the headers."""
        task = utils.retrieve_task_from_sender(kwargs)
        task_id = utils.retrieve_task_id_from_message(kwargs)
        if task_id is None:
            return
        if task is None:
            # task is an anonymous task send using send_task or using canvas workflow
            # Signatures() to send to a task not in the current processes dependency
            # tree
            task_name = kwargs.get("sender", "unknown")
        else:
            task_name = task.name
        operation_name = f"{_TASK_APPLY_ASYNC}/{task_name}"
        span = self._tracer.start_span(
            operation_name, kind=trace.SpanKind.PRODUCER
        )
        # apply some attributes here because most of the data is not available
        if span.is_recording():
            span.set_attribute(_TASK_TAG_KEY, _TASK_APPLY_ASYNC)
            span.set_attribute(MESSAGING_MESSAGE_ID, task_id)
            span.set_attribute(_TASK_NAME_KEY, task_name)
            utils.set_attributes_from_context(span, kwargs)
        activation = trace.use_span(span, end_on_exit=True)
        activation.__enter__()  # pylint: disable=unnecessary-dunder-call
        utils.attach_context(
            task, task_id, span, activation, None, is_publish=True
        )
        try:
            headers = kwargs.get("headers")
            if headers:
                # The publish span is current here, so it is what gets
                # propagated to the consumer.
                inject(headers)
        finally:
            if task is None:
                # _trace_after_publish uses the same task lookup and returns
                # early when it fails, so it would never exit this
                # activation. Close it here instead.
                activation.__exit__(None, None, None)  # pylint: disable=unnecessary-dunder-call

    @staticmethod
    def _trace_after_publish(*args, **kwargs):
        """End the PRODUCER span started by ``_trace_before_publish``."""
        task = utils.retrieve_task_from_sender(kwargs)
        task_id = utils.retrieve_task_id_from_message(kwargs)
        if task is None or task_id is None:
            return
        # retrieve and finish the Span
        ctx = utils.retrieve_context(task, task_id, is_publish=True)
        if ctx is None:
            logger.warning("no existing span found for task_id=%s", task_id)
            return
        _, activation, _ = ctx
        activation.__exit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
        utils.detach_context(task, task_id, is_publish=True)

    @staticmethod
    def _trace_failure(*args, **kwargs):
        """Mark the running task span as failed and record the exception.

        Exceptions that are instances of ``task.throws`` are treated as
        expected and leave the span status untouched.
        """
        task = utils.retrieve_task_from_sender(kwargs)
        task_id = utils.retrieve_task_id(kwargs)
        if task is None or task_id is None:
            return
        ctx = utils.retrieve_context(task, task_id)
        if ctx is None:
            return
        span, _, _ = ctx
        if not span.is_recording():
            return
        status_kwargs = {"status_code": StatusCode.ERROR}
        ex = kwargs.get("einfo")
        if ex is not None:
            # Unwrap the actual exception wrapped by billiard's
            # `ExceptionInfo` and `ExceptionWithTraceback`. This has to happen
            # before the `throws` check: when `ExceptionInfo.exception` is the
            # `ExceptionWithTraceback` wrapper it would never match the task's
            # expected exception types.
            if isinstance(ex, ExceptionInfo) and ex.exception is not None:
                ex = ex.exception
            if (
                ExceptionWithTraceback is not None
                and isinstance(ex, ExceptionWithTraceback)
                and ex.exc is not None
            ):
                ex = ex.exc
            # Expected exceptions (Task.throws) are not span errors.
            if hasattr(task, "throws") and isinstance(ex, task.throws):
                return
            status_kwargs["description"] = str(ex)
            span.record_exception(ex)
        span.set_status(Status(**status_kwargs))

    @staticmethod
    def _trace_retry(*args, **kwargs):
        """Record the retry reason on the running task span."""
        task = utils.retrieve_task_from_sender(kwargs)
        task_id = utils.retrieve_task_id_from_request(kwargs)
        reason = utils.retrieve_reason(kwargs)
        if task is None or task_id is None or reason is None:
            return
        ctx = utils.retrieve_context(task, task_id)
        if ctx is None:
            return
        span, _, _ = ctx
        if not span.is_recording():
            return
        # Add retry reason metadata to span
        # Use `str(reason)` instead of `reason.message` in case we get
        # something that isn't an `Exception`
        span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))

    def update_task_duration_time(self, task_id):
        """Record a start time for ``task_id``, or convert it to a duration.

        The first call for a ``task_id`` stores the current timer value. A
        subsequent call replaces it with the time elapsed since then.
        """
        cur_time = default_timer()
        task_duration_time_until_now = (
            cur_time - self.task_id_to_start_time[task_id]
            if task_id in self.task_id_to_start_time
            else cur_time
        )
        self.task_id_to_start_time[task_id] = task_duration_time_until_now

    def _record_histograms(self, task_id, metric_attributes):
        """Record the tracked duration of ``task_id`` in the runtime histogram.

        Does nothing when ``task_id`` is ``None``, when no duration is tracked
        for it, or when the metrics have not been created.
        """
        if task_id is None:
            return
        duration = self.task_id_to_start_time.get(task_id)
        if duration is None or self.metrics is None:
            return
        self.metrics["flower.task.runtime.seconds"].record(
            duration,
            attributes=metric_attributes,
        )

    def create_celery_metrics(self, meter) -> None:
        """Create the instruments used by this instrumentor on ``meter``."""
        self.metrics = {
            "flower.task.runtime.seconds": meter.create_histogram(
                name="flower.task.runtime.seconds",
                unit="seconds",
                description="The time it took to run the task.",
            )
        }