Commit 6c23135

[Iceberg] add rollback_to_previous_snapshot

1 parent 0616f79
1 file changed: +166 −5

octobot/community/history_backend/iceberg_historical_backend_client.py

@@ -33,6 +33,8 @@
 import pyiceberg.table
 import pyiceberg.table.sorting
 import pyiceberg.table.statistics
+import pyiceberg.table.update
+import pyiceberg.table.refs
 
 import octobot_commons.logging as commons_logging
 import octobot_commons.enums as commons_enums
@@ -263,6 +265,8 @@ async def insert_candles_history(self, rows: list, column_names: list) -> None:
         )
 
     def _sync_insert_candles_history(self, rows: typing.Iterable[list], column_names: list[str]) -> None:
+        # warning: try not to insert duplicate candles,
+        # however duplicates will be deduplicated later on anyway
         if not rows:
             return
         schema = self._pyarrow_get_ohlcv_schema()
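
The added warning defers deduplication to later maintenance. As a minimal sketch of what a pre-insert dedup step could look like (the column names below are assumptions based on the join_cols mentioned further down, not confirmed by this diff):

import pyarrow

# hypothetical candle identity and value columns; adjust to the real OHLCV schema
KEY_COLS = ["timestamp", "exchange_internal_name", "symbol", "time_frame"]
VALUE_COLS = ["open", "high", "low", "close", "volume"]

def drop_duplicate_candles(pa_table: pyarrow.Table) -> pyarrow.Table:
    # exact duplicates carry identical values, so any aggregate ("min" here) keeps each candle intact
    deduplicated = pa_table.group_by(KEY_COLS).aggregate([(col, "min") for col in VALUE_COLS])
    # aggregate() renames value columns to "<name>_min": strip the suffix and restore column order
    deduplicated = deduplicated.rename_columns([
        name[: -len("_min")] if name.endswith("_min") else name
        for name in deduplicated.column_names
    ])
    return deduplicated.select(KEY_COLS + VALUE_COLS)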
@@ -275,12 +279,22 @@ def _sync_insert_candles_history(self, rows: typing.Iterable[list], column_names: list[str]) -> None:
             for i, _ in enumerate(column_names)
         ]
         pa_table = pyarrow.Table.from_arrays(pa_arrays, schema=schema)
+        try:
+            table.append(pa_table)
+            # note: alternative upsert syntax could prevent duplicates but is really slow and silently crashes the process
+            # when used with a few thousand rows
+            # table.upsert(pa_table, join_cols=["timestamp", "exchange_internal_name", "symbol", "time_frame"])
+        except pyiceberg.exceptions.CommitFailedException as err:
+            # if this happens, it means there are conflicts. Log the error and let maintenance
+            # perform the rollback to the previous snapshot
+            self._get_logger().exception(
+                err,
+                True,
+                f"Commit failed: conflicts. Rolling back to previous snapshot might fix this {err}"
+            )
+            raise
+        # now that candles have been inserted, update metadata
         self._register_updated_min_max(table, pa_table)
-        # warning: try not to insert duplicate candles, duplicates will be deduplicated later on anyway
-        table.append(pa_table)
-        # note: alternative upsert syntax could prevent duplicates but is really slow and silently crashes the process
-        # when used with a few thousand rows
-        # table.upsert(pa_table, join_cols=["timestamp", "exchange_internal_name", "symbol", "time_frame"])
         self._get_logger().info(
             f"Successfully inserted {len(rows)} rows into "
             f"{TableNames.OHLCV_HISTORY.value} for {pa_table['exchange_internal_name'][0]}:{pa_table['symbol'][0]}:{pa_table['time_frame'][0]}"
@@ -312,6 +326,11 @@ async def _insert_and_reset_pending_data(self):
         # reset pending data
         self._pending_insert_data_by_table = {}
         for table_name, pending_data in to_insert_pending_data.items():
+            pending_rows_count = sum(
+                len(data.data)
+                for data in pending_data
+            )
+            self._get_logger().info(f"Inserting {table_name.value} {pending_rows_count} pending rows")
             await self._run_in_executor(self._sync_insert_table_pending_data, table_name, pending_data)
 
     def _sync_insert_table_pending_data(self, table_name: TableNames, pending_data: list[_PendingInsertData]):
@@ -738,3 +757,145 @@ def _get_logger(cls):
 
     def _has_metadata_to_update(self) -> bool:
         return bool(self._updated_min_max_per_symbol_per_time_frame_per_exchange)
+
+    async def rollback_to_previous_snapshot(
+        self, table_name: TableNames, snapshot_id: typing.Optional[int] = None
+    ) -> int:
+        return await self._run_in_executor(
+            self._sync_rollback_to_previous_snapshot, table_name, snapshot_id
+        )
+
+    def _sync_rollback_to_previous_snapshot(
+        self, table_name: TableNames, snapshot_id: typing.Optional[int] = None
+    ) -> int:
+        """Synchronous implementation of rollback_to_previous_snapshot"""
+        table = self._get_or_create_table(table_name)
+
+        # Get the snapshot history
+        history = table.history()
+
+        if not history:
+            raise ValueError(f"No snapshot history found for table {table_name.value}")
+
+        # Sort snapshots by timestamp (most recent first)
+        sorted_snapshots = sorted(
+            history,
+            key=lambda s: s.timestamp_ms,
+            reverse=True
+        )
+
+        if snapshot_id is not None:
+            # Rollback to a specific snapshot
+            target_snapshot_id = snapshot_id
+            if not any(s.snapshot_id == snapshot_id for s in sorted_snapshots):
+                raise ValueError(
+                    f"Snapshot ID {snapshot_id} not found in table {table_name.value} history"
+                )
+        else:
+            # Rollback to the previous snapshot (second most recent)
+            if len(sorted_snapshots) < 2:
+                raise ValueError(
+                    f"No previous snapshot to rollback to for table {table_name.value}. "
+                    f"Only {len(sorted_snapshots)} snapshot(s) available."
+                )
+            # sorted_snapshots[0] is the current snapshot, [1] is the previous one
+            target_snapshot_id = sorted_snapshots[1].snapshot_id
+
+        # Perform the rollback by updating the main branch to point to the target snapshot.
+        # This is the PyIceberg way to "rollback": update the branch reference to a previous snapshot
+        self._get_logger().info(
+            f"Rolling back table {table_name.value} to snapshot {target_snapshot_id}"
+        )
+
+        # Use a transaction to update the main branch reference to the target snapshot.
+        # Create the update directly to avoid the AssertRefSnapshotId requirement that causes 409 conflicts
+        with table.transaction() as txn:
+            # Create the SetSnapshotRefUpdate directly, without assertion requirements
+            update = pyiceberg.table.update.SetSnapshotRefUpdate(
+                ref_name=pyiceberg.table.refs.MAIN_BRANCH,
+                type=pyiceberg.table.refs.SnapshotRefType.BRANCH,
+                snapshot_id=target_snapshot_id,
+            )
+            txn._updates += (update,)
+            # Warning: don't add requirements - this allows the update to succeed even if there are concurrent changes
+            # (otherwise, the update would fail with a 409 conflict in case the target branch is corrupted)
+
+        self._get_logger().info(
+            f"Successfully rolled back table {table_name.value} to snapshot {target_snapshot_id}"
+        )
+
+        return target_snapshot_id
+
+
+    async def cleanup_old_snapshots_and_branches(
+        self,
+        table_name: TableNames,
+        older_than_s: int,
+        branches_to_delete: typing.Optional[list[str]] = None
+    ) -> tuple[int, int]:
+        return await self._run_in_executor(
+            self._sync_cleanup_old_snapshots_and_branches,
+            table_name,
+            older_than_s,
+            branches_to_delete
+        )
+
+    def _sync_cleanup_old_snapshots_and_branches(
+        self,
+        table_name: TableNames,
+        older_than_s: float,
+        branches_to_delete: typing.Optional[list[str]] = None
+    ) -> tuple[int, int]:
+        """Synchronous implementation of cleanup_old_snapshots_and_branches"""
+        table = self._get_or_create_table(table_name)
+        snapshots_expired = 0
+        branches_deleted = 0
+
+        table.maintenance.expire_snapshots().older_than(
+            datetime.datetime.fromtimestamp(
+                older_than_s, tz=datetime.timezone.utc
+            )
+        ).commit()
+
+        # Delete branches if branches_to_delete is provided
+        if branches_to_delete:
+            # Refresh table metadata to get current branches
+            table.refresh()
+            current_refs = table.metadata.refs
+
+            # Filter to only delete branches that exist and are not the main branch
+            valid_branches_to_delete = [
+                branch for branch in branches_to_delete
+                if branch in current_refs and branch != pyiceberg.table.refs.MAIN_BRANCH
+            ]
+
+            if valid_branches_to_delete:
+                self._get_logger().info(
+                    f"Deleting {len(valid_branches_to_delete)} branches for table {table_name.value}: "
+                    f"{', '.join(valid_branches_to_delete)}"
+                )
+
+                # Delete branches one by one using manage_snapshots
+                with table.manage_snapshots() as ms:
+                    for branch in valid_branches_to_delete:
+                        ms.remove_branch(branch_name=branch)
+
+                branches_deleted = len(valid_branches_to_delete)
+                self._get_logger().info(
+                    f"Successfully deleted {branches_deleted} branches for table {table_name.value}"
+                )
+            elif branches_to_delete:
+                # Log a warning if some branches were requested but don't exist or are protected
+                invalid_branches = [b for b in branches_to_delete if b not in current_refs or b == pyiceberg.table.refs.MAIN_BRANCH]
+                if invalid_branches:
+                    self._get_logger().warning(
+                        f"Skipped deletion of {len(invalid_branches)} branches (non-existent or protected): "
+                        f"{', '.join(invalid_branches)}"
+                    )
+
+        self._get_logger().info(
+            f"Cleanup complete for table {table_name.value}: "
+            f"{snapshots_expired} snapshots expired, {branches_deleted} branches deleted"
+        )
+
+        return snapshots_expired, branches_deleted
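
A minimal usage sketch of the two new maintenance entry points; the client variable, the 7-day retention window and the "staging" branch name are assumptions for illustration, not part of this commit:

import time

async def run_ohlcv_maintenance(client):
    # undo the latest (conflicting) commit by moving the main branch back one snapshot
    restored_snapshot_id = await client.rollback_to_previous_snapshot(TableNames.OHLCV_HISTORY)
    # older_than_s is an absolute epoch cutoff: expire everything older than 7 days ago
    # and drop a stale work branch (the main branch is always protected)
    await client.cleanup_old_snapshots_and_branches(
        TableNames.OHLCV_HISTORY,
        older_than_s=time.time() - 7 * 24 * 3600,
        branches_to_delete=["staging"],
    )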
