Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ build/
dist/
*.egg-info/
*.pyc
.venv/
251 changes: 150 additions & 101 deletions cursor_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,113 @@
from django.utils.translation import gettext_lazy as _


class CursorStrategy:
"""Base interface for cursor pagination strategies."""

def get_ordering(self, ordering, from_last=False):
"""Transform ordering fields for database queries according to the strategy."""
raise NotImplementedError

def build_cursor_filter(
self, ordering, cursor_values, reverse=False, from_last=False
):
"""Build the cursor filter using the strategy's approach."""
raise NotImplementedError


class DefaultCursorStrategy(CursorStrategy):
"""Default strategy maintaining current NULLS LAST behavior."""

def get_ordering(self, ordering, from_last=False):
"""
Transform ordering fields with explicit NULL handling for consistent behavior.

This clarifies that NULL values come at the end in the sort.
When "from_last" is specified, NULL values come first since we return
the results in reversed order.
"""
nulls_ordering = []
for key in ordering:
is_reversed = key.startswith('-')
column = key.lstrip('-')
if is_reversed:
if from_last:
nulls_ordering.append(F(column).desc(nulls_first=True))
else:
nulls_ordering.append(F(column).desc(nulls_last=True))
else:
if from_last:
nulls_ordering.append(F(column).asc(nulls_first=True))
else:
nulls_ordering.append(F(column).asc(nulls_last=True))

return nulls_ordering

def build_cursor_filter(
self, ordering, cursor_values, reverse=False, from_last=False
):
"""
Build the cursor filter using the current OR logic and NULL handling.
This is the existing implementation from the apply_cursor method.
"""
if not ordering or not cursor_values:
return Q()

if len(ordering) != len(cursor_values):
raise ValueError("Ordering and cursor values must match length")

# Convert cursor values for comparison
position_values = [
Value(pos, output_field=TextField()) if pos is not None else None
for pos in cursor_values
]

# Build Q object with OR logic and NULL handling (current implementation)
filtering = Q()
q_equality = {}

for ordering_field, value in zip(ordering, position_values):
is_reversed = ordering_field.startswith('-')
o = ordering_field.lstrip('-')
if value is None: # cursor value for the key was NULL
key = "{}__isnull".format(o)
if (
from_last is True
): # if from_last & cursor value is NULL, we need to get non Null for the key
q = {key: False}
q.update(q_equality)
filtering |= Q(**q)

q_equality.update({key: True})
else: # cursor value for the key was non NULL
if reverse != is_reversed:
comparison_key = "{}__lt".format(o)
else:
comparison_key = "{}__gt".format(o)

q = Q(**{comparison_key: value})
if not from_last: # if not from_last, NULL values are still candidates
q |= Q(**{"{}__isnull".format(o): True})
filtering |= (q) & Q(**q_equality)

equality_key = "{}__exact".format(o)
q_equality.update({equality_key: value})

return filtering


class PreserveOrderingStrategy(DefaultCursorStrategy):
"""
Cursor strategy that preserves the ordering of the fields.
"""

def get_ordering(self, ordering, from_last=False):
"""
Return simple ordering fields.
"""
return ordering


class InvalidCursor(Exception):
pass

Expand All @@ -14,6 +121,7 @@ def reverse_ordering(ordering_tuple):
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""

def invert(x):
return x[1:] if (x.startswith('-')) else '-' + x

Expand All @@ -34,41 +142,33 @@ def __getitem__(self, key):
return self.items.__getitem__(key)

def __repr__(self):
return '<Page: [%s%s]>' % (', '.join(repr(i) for i in self.items[:21]), ' (remaining truncated)' if len(self.items) > 21 else '')
return '<Page: [%s%s]>' % (
', '.join(repr(i) for i in self.items[:21]),
' (remaining truncated)' if len(self.items) > 21 else '',
)


class CursorPaginator(object):
delimiter = '|'
none_string = '::None'
invalid_cursor_message = _('Invalid cursor')

def __init__(self, queryset, ordering):
self.queryset = queryset.order_by(*self._nulls_ordering(ordering))
def __init__(self, queryset, ordering, strategy=None):
self.ordering = ordering
self.strategy = strategy or DefaultCursorStrategy()
self.queryset = queryset.order_by(*self._get_ordering(ordering))

def _nulls_ordering(self, ordering, from_last=False):
"""
This clarifies that NULL value comes at the end in the sort.
When "from_last" is specified, NULL value comes first since we return the results in reversed order.
"""
nulls_ordering = []
for key in ordering:
is_reversed = key.startswith('-')
column = key.lstrip('-')
if is_reversed:
if from_last:
nulls_ordering.append(F(column).desc(nulls_first=True))
else:
nulls_ordering.append(F(column).desc(nulls_last=True))
else:
if from_last:
nulls_ordering.append(F(column).asc(nulls_first=True))
else:
nulls_ordering.append(F(column).asc(nulls_last=True))
def _get_ordering(self, ordering):
"""Get database ordering using the current strategy."""
return self.strategy.get_ordering(ordering)

return nulls_ordering
def _nulls_ordering(self, ordering, from_last=False):
"""Deprecated: Use strategy.get_ordering instead."""
return self.strategy.get_ordering(ordering, from_last)

def _apply_paginator_arguments(self, qs, first=None, last=None, after=None, before=None):
def _apply_paginator_arguments(
self, qs, first=None, last=None, after=None, before=None
):
"""
Apply first/after, last/before filtering to the queryset
"""
Expand All @@ -81,9 +181,11 @@ def _apply_paginator_arguments(self, qs, first=None, last=None, after=None, befo
if before is not None:
qs = self.apply_cursor(before, qs, from_last=from_last, reverse=True)
if first is not None:
qs = qs[:first + 1]
qs = qs[: first + 1]
if last is not None:
qs = qs.order_by(*self._nulls_ordering(reverse_ordering(self.ordering), from_last=True))[:last + 1]
qs = qs.order_by(
*self._nulls_ordering(reverse_ordering(self.ordering), from_last=True)
)[: last + 1]

return qs

Expand Down Expand Up @@ -128,90 +230,28 @@ async def apage(self, first=None, last=None, after=None, before=None):
return self._get_cursor_page(items, has_additional, first, last, after, before)

def apply_cursor(self, cursor, queryset, from_last, reverse=False):
"""Apply cursor using the current strategy."""
position = self.decode_cursor(cursor)

# this was previously implemented as tuple comparison done on postgres side
# Assume comparing 3-tuples a and b,
# the comparison a < b is equivalent to:
# (a.0 < b.0) || (a.0 == b.0 && (a.1 < b.1)) || (a.0 == b.0 && a.1 == b.1 && (a.2 < b.2))
# The expression above does not depend on short-circuit evalution support,
# which is usually unavailable on backend RDB

# In order to reflect that in DB query,
# we need to generate a corresponding WHERE-clause.

# Suppose we have ordering ("field1", "-field2", "field3")
# (note negation 2nd item),
# and corresponding cursor values are ("value1", "value2", "value3"),
# `reverse` is False.
# In order to apply cursor, we need to generate a following WHERE-clause:

# WHERE ((field1 < value1 OR field1 IS NULL) OR
# (field1 = value1 AND (field2 > value2 OR field2 IS NULL)) OR
# (field1 = value1 AND field2 = value2 AND (field3 < value3 IS NULL)).
#
# Keep in mind, NULL is considered the last part of each field's order.
# We will use `__lt` lookup for `<`,
# `__gt` for `>` and `__exact` for `=`.
# (Using case-sensitive comparison as long as
# cursor values come from the DB against which it is going to be compared).
# The corresponding django ORM construct would look like:
# filter(
# Q(field1__lt=Value(value1) OR field1__isnull=True) |
# Q(field1__exact=Value(value1), (Q(field2__gt=Value(value2) | Q(field2__isnull=True)) |
# Q(field1__exact=Value(value1), field2__exact=Value(value2), (Q(field3__lt=Value(value3) | Q(field3__isnull=True)))
# )

# In order to remember which keys we need to compare for equality on the next iteration,
# we need an accumulator in which we store all the previous keys.
# When we are generating a Q object for j-th position/ordering pair,
# our q_equality would contain equality lookups
# for previous pairs of 0-th to (j-1)-th pairs.
# That would allow us to generate a Q object like the following:
# Q(f1__exact=Value(v1), f2__exact=Value(v2), ..., fj_1__exact=Value(vj_1), fj__lt=Value(vj)),
# where the last item would depend on both "reverse" option and ordering key sign.

filtering = Q()
q_equality = {}

position_values = [Value(pos, output_field=TextField()) if pos is not None else None for pos in position]

for ordering, value in zip(self.ordering, position_values):
is_reversed = ordering.startswith('-')
o = ordering.lstrip('-')
if value is None: # cursor value for the key was NULL
key = "{}__isnull".format(o)
if from_last is True: # if from_last & cursor value is NULL, we need to get non Null for the key
q = {key : False}
q.update(q_equality)
filtering |= Q(**q)

q_equality.update({key: True})
else: # cursor value for the key was non NULL
if reverse != is_reversed:
comparison_key = "{}__lt".format(o)
else:
comparison_key = "{}__gt".format(o)

q = Q(**{comparison_key: value})
if not from_last: # if not from_last, NULL values are still candidates
q |= Q(**{"{}__isnull".format(o): True})
filtering |= (q) & Q(**q_equality)

equality_key = "{}__exact".format(o)
q_equality.update({equality_key: value})

return queryset.filter(filtering)
return queryset.filter(
self.strategy.build_cursor_filter(
self.ordering, position, reverse, from_last
)
)

def decode_cursor(self, cursor):
try:
orderings = b64decode(cursor.encode('ascii')).decode('utf8')
return [ordering if ordering != self.none_string else None for ordering in orderings.split(self.delimiter)]
return [
ordering if ordering != self.none_string else None
for ordering in orderings.split(self.delimiter)
]
except (TypeError, ValueError):
raise InvalidCursor(self.invalid_cursor_message)

def encode_cursor(self, position):
encoded = b64encode(self.delimiter.join(position).encode('utf8')).decode('ascii')
encoded = b64encode(self.delimiter.join(position).encode('utf8')).decode(
'ascii'
)
return encoded

def position_from_instance(self, instance):
Expand All @@ -231,3 +271,12 @@ def position_from_instance(self, instance):
def cursor(self, instance):
return self.encode_cursor(self.position_from_instance(instance))

@classmethod
def for_preserve_ordering(cls, queryset, ordering):
"""Create cursor paginator using PreserveOrderingStrategy."""
return cls(queryset, ordering, strategy=PreserveOrderingStrategy())

@classmethod
def with_strategy(cls, queryset, ordering, strategy):
"""Create a cursor paginator with a custom strategy."""
return cls(queryset, ordering, strategy=strategy)
1 change: 0 additions & 1 deletion runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from django.conf import settings
from django.test.utils import get_runner


if __name__ == '__main__':
os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
django.setup()
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from setuptools import setup


with open("README.md", "r") as fh:
long_description = fh.read()

Expand Down
Loading