Skip to content

Commit 23beed7

Browse files
sbenzaquen and copybara-github
authored and committed
Add Inserter class that avoids data reloads when rehashing/inserting many elements.
This class caches all the relevant fields from the map so they stay on the stack or in registers.
PiperOrigin-RevId: 820727393
1 parent 69dbf46 commit 23beed7

File tree

2 files changed

+101
-11
lines changed

2 files changed

+101
-11
lines changed

src/google/protobuf/map.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ void UntypedMapBase::DeleteNode(NodeBase* node) {
137137
DeallocNode(node);
138138
}
139139

140+
// Walks the singly linked chain that starts at `list` and hands every node to
// DeleteNode(). The successor pointer is read before the node is deleted, so
// the walk never dereferences a node that has already been passed to
// DeleteNode(). A null `list` is a no-op.
void UntypedMapBase::DeleteList(NodeBase* list) {
  for (NodeBase* cursor = list; cursor != nullptr;) {
    NodeBase* const doomed = cursor;
    cursor = cursor->next;
    DeleteNode(doomed);
  }
}
147+
140148
void UntypedMapBase::ClearTableImpl(Arena* arena, bool reset) {
141149
ABSL_DCHECK_NE(num_buckets_, kGlobalEmptyTableSize);
142150
ABSL_DCHECK_EQ(arena, this->arena());

src/google/protobuf/map.h

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
514514
}
515515

516516
void DeleteNode(NodeBase* node);
517+
void DeleteList(NodeBase* list);
517518

518519
map_index_t num_elements_;
519520
map_index_t num_buckets_;
@@ -857,17 +858,44 @@ class KeyMapBase : public UntypedMapBase {
857858
// Inserts the first `count` nodes of the singly linked `list` into the table.
//
// For each node the chain of its bucket is scanned:
//  * if no node with an equal key is found, the node is inserted through the
//    Inserter and counted as a new element;
//  * otherwise the node is spliced into the chain at the position of the
//    existing node, and the existing node is pushed onto a local list of
//    replaced nodes.
//
// After all nodes are processed, `num_elements_` is written once and, when
// `arena` is null, the replaced nodes are passed to DeleteList().
// `list` must hold at least `count` nodes (DCHECKed on every step).
void InsertOrReplaceNodes(Arena* arena, KeyNode* list, map_index_t count) {
  // Any resize has to happen before the Inserter snapshots the table fields.
  ResizeIfLoadIsOutOfRangeForMultiInsert(arena, num_elements_ + count);

  Inserter inserter(this);
  // Element count is tracked locally and stored back once after the loop.
  map_index_t element_count = num_elements_;
  // Nodes displaced by a same-key insertion, chained through `next`.
  NodeBase* replaced = nullptr;

  for (map_index_t remaining = count; remaining != 0; --remaining) {
    ABSL_DCHECK_NE(list, nullptr);
    KeyNode* const incoming = list;
    list = static_cast<KeyNode*>(incoming->next);

    const map_index_t bucket = inserter.BucketNumber(incoming);

    // `link` always points at the pointer that refers to `existing`, so a
    // match can be replaced in place without a second traversal.
    NodeBase** link = &table_[bucket];
    auto* existing = static_cast<KeyNode*>(*link);
    while (existing != nullptr && !(existing->key() == incoming->key())) {
      link = &(*link)->next;
      existing = static_cast<KeyNode*>(*link);
    }

    if (ABSL_PREDICT_FALSE(existing != nullptr)) {
      // Same key already present: take over its slot in the chain...
      *link = incoming;
      incoming->next = existing->next;
      // ...and queue the old node for deletion after the loop.
      existing->next = replaced;
      replaced = existing;
    } else {
      // Reached the end of the chain without a match: brand new element.
      inserter.InsertUnique(incoming, bucket);
      ++element_count;
    }
  }

  num_elements_ = element_count;

  if (ABSL_PREDICT_FALSE(arena == nullptr && replaced != nullptr)) {
    DeleteList(replaced);
  }
}
873901

@@ -1001,11 +1029,13 @@ class KeyMapBase : public UntypedMapBase {
10011029
ResizeIfLoadIsOutOfRangeForMultiInsert(arena, num_nodes);
10021030
num_elements_ = num_nodes;
10031031
AssertLoadFactor();
1004-
while (head != nullptr) {
1032+
Inserter inserter(this);
1033+
for (size_t i = 0; i < num_nodes; ++i) {
10051034
KeyNode* node = static_cast<KeyNode*>(head);
1035+
ABSL_DCHECK_NE(node, nullptr);
10061036
head = head->next;
10071037
absl::PrefetchToLocalCacheNta(head);
1008-
InsertUnique(BucketNumber(TS::ToView(node->key())), node);
1038+
inserter.InsertUnique(node);
10091039
}
10101040
}
10111041

@@ -1037,18 +1067,70 @@ class KeyMapBase : public UntypedMapBase {
10371067
const map_index_t start = index_of_first_non_null_;
10381068
index_of_first_non_null_ = num_buckets_;
10391069
#endif
1070+
Inserter inserter(this);
10401071
for (map_index_t i = start; i < old_table_size; ++i) {
10411072
for (KeyNode* node = static_cast<KeyNode*>(old_table[i]);
10421073
node != nullptr;) {
10431074
auto* next = static_cast<KeyNode*>(node->next);
1044-
InsertUnique(BucketNumber(TS::ToView(node->key())), node);
1075+
inserter.InsertUnique(node);
10451076
node = next;
10461077
}
10471078
}
10481079
DeleteTable(arena, old_table, old_table_size);
10491080
AssertLoadFactor();
10501081
}
10511082

1083+
// Caches all the relevant values of `UntypedMapBase` to hold a copy on the
1084+
// stack and avoid reloads after every write.
1085+
// It allows inserting multiple nodes in a row with reduced cost.
1086+
class Inserter {
1087+
public:
1088+
explicit Inserter(KeyMapBase* map)
1089+
: table_(map->table_),
1090+
mask_(map->num_buckets_ - 1),
1091+
#ifndef PROTOBUF_INTERNAL_REMOVE_ARENA_PTRS_MAP_FIELD
1092+
index_of_first_non_null_(map->index_of_first_non_null_),
1093+
#endif
1094+
map_(map) {
1095+
}
1096+
1097+
#ifndef PROTOBUF_INTERNAL_REMOVE_ARENA_PTRS_MAP_FIELD
1098+
~Inserter() {
1099+
// Flush the value at the end.
1100+
map_->index_of_first_non_null_ = index_of_first_non_null_;
1101+
}
1102+
#endif
1103+
1104+
map_index_t BucketNumber(KeyNode* node) const {
1105+
return Hash(node->key(), table_) & mask_;
1106+
}
1107+
1108+
void InsertUnique(KeyNode* node, map_index_t bucket) {
1109+
ABSL_DCHECK_EQ(bucket, BucketNumber(node));
1110+
auto*& head = table_[bucket];
1111+
if (head != nullptr && map_->ShouldInsertAfterHead(node)) {
1112+
node->next = head->next;
1113+
head->next = node;
1114+
} else {
1115+
node->next = head;
1116+
head = node;
1117+
#ifndef PROTOBUF_INTERNAL_REMOVE_ARENA_PTRS_MAP_FIELD
1118+
index_of_first_non_null_ = (std::min)(index_of_first_non_null_, bucket);
1119+
#endif
1120+
}
1121+
}
1122+
1123+
void InsertUnique(KeyNode* node) { InsertUnique(node, BucketNumber(node)); }
1124+
1125+
private:
1126+
NodeBase** const table_;
1127+
const map_index_t mask_;
1128+
#ifndef PROTOBUF_INTERNAL_REMOVE_ARENA_PTRS_MAP_FIELD
1129+
map_index_t index_of_first_non_null_;
1130+
#endif
1131+
KeyMapBase* const map_;
1132+
};
1133+
10521134
// Returns the bucket index for the key view `k`: Hash(k, table_) masked with
// `num_buckets_ - 1`.
map_index_t BucketNumber(typename TS::ViewType k) const {
  const auto bucket_mask = num_buckets_ - 1;
  return Hash(k, table_) & bucket_mask;
}

0 commit comments

Comments
 (0)