Skip to content

Commit 5849dd3

Browse files
committed
Add support for loading sharded Riegeli frequency files.
Sharded data files are of the form <filename>-?????-of-?????. Where the data for a single frequency data set is split among multiple files.
1 parent edcbdae commit 5849dd3

19 files changed

+184
-20
lines changed

util/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ cc_test(
217217
"testdata/codepoints_with_freq_invalid.txt",
218218
"testdata/test_freq_data.riegeli",
219219
"testdata/invalid_test_freq_data.riegeli",
220-
],
220+
] + glob([
221+
"testdata/sharded/*",
222+
]),
221223
deps = [
222224
":load_codepoints",
223225
"@googletest//:gtest_main",

util/freq_data_to_sorted_codepoints.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <iomanip>
55
#include <iostream>
66
#include <locale>
7+
#include <ostream>
78
#include <string>
89
#include <vector>
910

@@ -40,7 +41,11 @@ int main(int argc, char** argv) {
4041

4142
if (args.size() != 2) {
4243
std::cerr << "Usage:" << std::endl
43-
<< "freq_data_to_sorted_codepoints <riegeli_file>" << std::endl;
44+
<< "freq_data_to_sorted_codepoints <riegeli_file>" << std::endl
45+
<< std::endl
46+
<< "Append @* to the file name to load sharded data files. "
47+
<< "For example \"<path>@*\" will load all files of the form <path>-?????-of-?????"
48+
<< std::endl;
4449
return -1;
4550
}
4651

util/generate_riegeli_test_data.cc

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
#include "absl/flags/flag.h"
77
#include "absl/flags/parse.h"
88
#include "absl/status/status.h"
9+
#include "absl/strings/str_cat.h"
910
#include "riegeli/bytes/fd_writer.h"
1011
#include "riegeli/records/record_writer.h"
1112
#include "util/unicode_count.pb.h"
1213

14+
using absl::StrCat;
15+
1316
ABSL_FLAG(std::string, output_path, "", "Path to write the output file.");
1417

1518
ABSL_FLAG(bool, include_invalid_record, false, "If set add an invalid record.");
19+
ABSL_FLAG(bool, shard, false, "If set shard into multiple files.");
1620

1721
namespace {
1822

@@ -41,24 +45,44 @@ absl::Status Main() {
4145
message4.add_codepoints(0x44);
4246
message4.set_count(75);
4347

44-
riegeli::RecordWriter writer{riegeli::FdWriter(output_path)};
45-
writer.WriteRecord(message1);
46-
writer.WriteRecord(message2);
47-
writer.WriteRecord(message3);
48-
writer.WriteRecord(message4);
49-
50-
if (absl::GetFlag(FLAGS_include_invalid_record)) {
51-
CodepointCount message5;
52-
message5.add_codepoints(0x46);
53-
message5.add_codepoints(0x46);
54-
message5.add_codepoints(0x46);
55-
message5.set_count(75);
56-
writer.WriteRecord(message5);
57-
}
48+
if (!absl::GetFlag(FLAGS_shard)) {
49+
riegeli::RecordWriter writer{riegeli::FdWriter(output_path)};
50+
writer.WriteRecord(message1);
51+
writer.WriteRecord(message2);
52+
writer.WriteRecord(message3);
53+
writer.WriteRecord(message4);
5854

59-
if (!writer.Close()) {
60-
return writer.status();
55+
if (absl::GetFlag(FLAGS_include_invalid_record)) {
56+
CodepointCount message5;
57+
message5.add_codepoints(0x46);
58+
message5.add_codepoints(0x46);
59+
message5.add_codepoints(0x46);
60+
message5.set_count(75);
61+
writer.WriteRecord(message5);
62+
}
63+
64+
if (!writer.Close()) {
65+
return writer.status();
66+
}
67+
} else {
68+
{
69+
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00000-of-00003"))};
70+
writer.WriteRecord(message1);
71+
writer.WriteRecord(message2);
72+
writer.Close();
73+
}
74+
{
75+
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00001-of-00003"))};
76+
writer.WriteRecord(message3);
77+
writer.Close();
78+
}
79+
{
80+
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00002-of-00003"))};
81+
writer.WriteRecord(message4);
82+
writer.Close();
83+
}
6184
}
85+
6286
return absl::OkStatus();
6387
}
6488

util/load_codepoints.cc

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include "load_codepoints.h"
22

3+
#include <algorithm>
4+
#include <filesystem>
35
#include <fstream>
46
#include <iostream>
57
#include <optional>
8+
#include <regex>
69
#include <sstream>
710

811
#include "absl/strings/str_cat.h"
@@ -14,6 +17,7 @@
1417
#include "riegeli/records/record_reader.h"
1518
#include "util/unicode_count.pb.h"
1619

20+
using absl::Status;
1721
using absl::StatusOr;
1822
using absl::StrCat;
1923
using absl::string_view;
@@ -122,8 +126,53 @@ StatusOr<std::vector<CodepointAndFrequency>> LoadCodepointsOrdered(
122126
return out;
123127
}
124128

125-
StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
126-
UnicodeFrequencies frequencies;
129+
StatusOr<std::vector<std::string>> ExpandShardedPath(const char* path) {
130+
std::string full_path(path);
131+
132+
if (!full_path.ends_with("@*")) {
133+
if (!std::filesystem::exists(full_path)) {
134+
return absl::NotFoundError(StrCat("Path does not exist: ", full_path));
135+
}
136+
return std::vector<std::string>{full_path};
137+
}
138+
139+
std::filesystem::path file_path = full_path.substr(0, full_path.size() - 2);
140+
std::string base_name = file_path.filename();
141+
std::filesystem::path directory = file_path.parent_path();
142+
143+
// Find the list of files matching the pattern:
144+
// <base name>-?????-of-?????
145+
std::regex file_pattern("^.*-[0-9]{5}-of-[0-9]{5}$");
146+
147+
if (!std::filesystem::exists(directory) ||
148+
!std::filesystem::is_directory(directory)) {
149+
return absl::NotFoundError(StrCat(
150+
"Path does not exist or is not a directory: ", directory.string()));
151+
}
152+
153+
// Collect into a set to ensure the output is sorted.
154+
absl::btree_set<std::string> files;
155+
for (const auto& entry : std::filesystem::directory_iterator(directory)) {
156+
std::string name = entry.path().filename();
157+
if (!name.starts_with(base_name)) {
158+
continue;
159+
}
160+
161+
if (std::regex_match(name, file_pattern)) {
162+
files.insert(entry.path());
163+
}
164+
}
165+
166+
if (files.empty()) {
167+
return absl::NotFoundError(StrCat("No files matched the shard pattern: ", full_path));
168+
}
169+
170+
return std::vector<std::string>(files.begin(), files.end());
171+
}
172+
173+
static Status LoadFrequenciesFromRiegeliIndividual(
174+
const char* path, UnicodeFrequencies& frequencies
175+
) {
127176
riegeli::RecordReader reader{riegeli::FdReader(path)};
128177
if (!reader.ok()) {
129178
return absl::InvalidArgumentError(
@@ -144,6 +193,15 @@ StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
144193
if (!reader.Close()) {
145194
return absl::InternalError(reader.status().message());
146195
}
196+
return absl::OkStatus();
197+
}
198+
199+
StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
200+
auto paths = TRY(ExpandShardedPath(path));
201+
UnicodeFrequencies frequencies;
202+
for (const auto& path : paths) {
203+
TRYV(LoadFrequenciesFromRiegeliIndividual(path.c_str(), frequencies));
204+
}
147205
return frequencies;
148206
}
149207

util/load_codepoints.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,21 @@ absl::StatusOr<common::FontData> LoadFile(const char* path);
3636

3737
// Loads a Riegeli file of CodepointCount protos and returns a
3838
// UnicodeFrequencies instance.
39+
//
40+
// Append "@*" to the path to load all sharded files for this path.
41+
// For example "FrequencyData.riegeli@*" will load all files of the
42+
// form FrequencyData.riegeli-*-of-* into the frequency data set.
3943
absl::StatusOr<ift::freq::UnicodeFrequencies> LoadFrequenciesFromRiegeli(
4044
const char* path);
4145

46+
// Given a filepath if it ends with @* this will expand the path into
47+
// the list of paths matching the pattern: <path>-?????-of-?????
48+
// Otherwise returns just the input path.
49+
//
50+
// Checks that the input path exists and will return a NotFoundError if
51+
// it does not.
52+
absl::StatusOr<std::vector<std::string>> ExpandShardedPath(const char* path);
53+
4254
struct CodepointAndFrequency {
4355
uint32_t codepoint;
4456
std::optional<uint64_t> frequency;

util/load_codepoints_test.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,73 @@ TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli) {
8686
EXPECT_EQ(result->ProbabilityFor(0x44, 0x45), 0.25);
8787
}
8888

89+
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded) {
90+
auto result =
91+
util::LoadFrequenciesFromRiegeli("util/testdata/sharded/test_freq_data.riegeli@*");
92+
ASSERT_TRUE(result.ok()) << result.status();
93+
94+
EXPECT_EQ(result->ProbabilityFor(0x43, 0x43), 1.0);
95+
EXPECT_EQ(result->ProbabilityFor(0x44, 0x44), 75.0 / 200.0);
96+
97+
EXPECT_EQ(result->ProbabilityFor(0x41, 0x42), 0.5);
98+
EXPECT_EQ(result->ProbabilityFor(0x44, 0x45), 0.25);
99+
}
100+
101+
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded_DoesNotExist) {
102+
auto result =
103+
util::LoadFrequenciesFromRiegeli("util/testdata/sharded/notfound.riegeli@*");
104+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
105+
}
106+
89107
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_BadData) {
90108
auto result = util::LoadFrequenciesFromRiegeli(
91109
"util/testdata/invalid_test_freq_data.riegeli");
92110
ASSERT_TRUE(absl::IsInvalidArgument(result.status())) << result.status();
93111
}
94112

113+
TEST_F(LoadCodepointsTest, ExpandShardedPath) {
114+
auto result = ExpandShardedPath("util/testdata/test_freq_data.riegeli");
115+
ASSERT_TRUE(result.ok()) << result.status();
116+
ASSERT_EQ(*result,
117+
(std::vector<std::string>{"util/testdata/test_freq_data.riegeli"}));
118+
119+
result = ExpandShardedPath("util/testdata/test_freq_data.riegeli@*");
120+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
121+
122+
result = ExpandShardedPath("util/testdata/sharded/BadSuffix@*");
123+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
124+
125+
result = ExpandShardedPath("does/not/exist.file@*");
126+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
127+
128+
result = ExpandShardedPath("util/testdata/sharded/notfound.file@*");
129+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
130+
131+
result = ExpandShardedPath("does/not/exist.file");
132+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
133+
134+
result = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli@*");
135+
ASSERT_TRUE(result.ok()) << result.status();
136+
ASSERT_EQ(*result,
137+
(std::vector<std::string>{
138+
"util/testdata/sharded/Language_ja.riegeli-00000-of-00003",
139+
"util/testdata/sharded/Language_ja.riegeli-00001-of-00003",
140+
"util/testdata/sharded/Language_ja.riegeli-00002-of-00003",
141+
}));
142+
143+
result = ExpandShardedPath("util/testdata/sharded/Language_ko.riegeli@*");
144+
ASSERT_TRUE(result.ok()) << result.status();
145+
ASSERT_EQ(*result,
146+
(std::vector<std::string>{
147+
"util/testdata/sharded/Language_ko.riegeli-00000-of-00100",
148+
"util/testdata/sharded/Language_ko.riegeli-00008-of-00100",
149+
"util/testdata/sharded/Language_ko.riegeli-00011-of-00100",
150+
"util/testdata/sharded/Language_ko.riegeli-00013-of-00100",
151+
"util/testdata/sharded/Language_ko.riegeli-00020-of-00100",
152+
}));
153+
154+
result = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli");
155+
ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status();
156+
}
157+
95158
} // namespace util

util/testdata/sharded/BadSuffix-00-of-02

Whitespace-only changes.

util/testdata/sharded/BadSuffix-01-of-02

Whitespace-only changes.

util/testdata/sharded/Language_ja.riegeli-00000-of-00003

Whitespace-only changes.

util/testdata/sharded/Language_ja.riegeli-00001-of-00003

Whitespace-only changes.

0 commit comments

Comments
 (0)