Skip to content

Commit bd3e235

Browse files
DFrenkelsrenatus
authored andcommitted
builtins: product, sort, sum
Also: implementing internal.member_3 for compliance testing Resolves #12 Signed-off-by: Dmitry Frenkel <[email protected]>
1 parent 30dfa0b commit bd3e235

File tree

5 files changed

+437
-1
lines changed

5 files changed

+437
-1
lines changed

Sources/Rego/Builtins/Aggregates.swift

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,71 @@ extension BuiltinFuncs {
4646
arg: "collection", got: args[0].typeName, want: "array|set")
4747
}
4848
}
49+
50+
static func sort(ctx: BuiltinContext, args: [AST.RegoValue]) async throws -> AST.RegoValue {
51+
guard args.count == 1 else {
52+
throw BuiltinError.argumentCountMismatch(got: args.count, want: 1)
53+
}
54+
55+
switch args[0] {
56+
case .array(let a):
57+
return .array(a.sorted())
58+
case .set(let s):
59+
return .array(s.sorted())
60+
default:
61+
throw BuiltinError.argumentTypeMismatch(
62+
arg: "collection", got: args[0].typeName, want: "any<array[any], set[any]>")
63+
}
64+
}
65+
66+
static func sum(ctx: BuiltinContext, args: [AST.RegoValue]) async throws -> AST.RegoValue {
67+
return try await doReduce(ctx: ctx, args: args, initialValue: .number(0), op: BuiltinFuncs.plus)
68+
}
69+
70+
static func product(ctx: BuiltinContext, args: [AST.RegoValue]) async throws -> AST.RegoValue {
71+
return try await doReduce(ctx: ctx, args: args, initialValue: .number(1), op: BuiltinFuncs.mul)
72+
}
73+
74+
/// Returns reduction over an array or set of RegoValues with a given async Builtin being used as an reducer operation.
75+
/// Returns the normalized metric unit symbol for a given symbol.
76+
/// - Parameters:
77+
/// - ctx: The builtin context.
78+
/// - args: The arguments to reduce.
79+
/// - initialValue: The initial value to start with.
80+
/// - op: The Arithmetic builtin operation to be applied to the partial result an the next value in the sequence to produce the next result.
81+
private static func doReduce(
82+
ctx: BuiltinContext, args: [AST.RegoValue],
83+
initialValue: AST.RegoValue,
84+
op: (BuiltinContext, [AST.RegoValue]) async throws -> AST.RegoValue
85+
) async throws -> AST.RegoValue {
86+
guard args.count == 1 else {
87+
throw BuiltinError.argumentCountMismatch(got: args.count, want: 1)
88+
}
89+
90+
// We will iterate over this sequence
91+
var sequence: any Sequence<RegoValue>
92+
switch args[0] {
93+
case .array(let a):
94+
sequence = a
95+
case .set(let s):
96+
sequence = s
97+
default:
98+
throw BuiltinError.argumentTypeMismatch(
99+
arg: "collection", got: args[0].typeName, want: "any<array[number], set[number]>")
100+
}
101+
102+
do {
103+
// Can't use synchronous reduce here
104+
var result = initialValue
105+
for element in sequence {
106+
result = try await op(ctx, [result, element])
107+
}
108+
return result
109+
} catch is RegoError {
110+
let receivedTypes = Set(sequence.map({ $0.typeName })).joined(separator: ", ")
111+
throw BuiltinError.argumentTypeMismatch(
112+
arg: "collection", got: "\(args[0].typeName)[any<\(receivedTypes)>]",
113+
want: "any<array[number], set[number]>")
114+
}
115+
}
49116
}

Sources/Rego/Builtins/Collections.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,47 @@ extension BuiltinFuncs {
2121
return .boolean(false)
2222
}
2323
}
24+
25+
// internal.member_3
26+
// memberOfWithKey is a membership check with key:
27+
// memberOfWithKey(k: any, x: any, y: any) checks if y has property or index k and it is equal to x
28+
// For objects, we are checking the keys AND the values.
29+
static func isMemberOfWithKey(ctx: BuiltinContext, args: [AST.RegoValue]) async throws -> AST.RegoValue {
30+
guard args.count == 3 else {
31+
throw BuiltinError.argumentCountMismatch(got: args.count, want: 3)
32+
}
33+
34+
let key = args[0]
35+
let value = args[1]
36+
// See https://github.com/open-policy-agent/opa/blob/b942136a4ad049262fd72026421dac6bdd705059/v1/topdown/aggregates.go#L247
37+
let match = args[2][key]
38+
if match != nil {
39+
return .boolean(RegoValue.compare(value, match!) == .orderedSame)
40+
}
41+
return .boolean(false)
42+
}
43+
}
44+
45+
extension RegoValue {
46+
/// Some RegoValues implement Get(key) interface. We will just implement it here as a subscript extension.
47+
/// See https://github.com/open-policy-agent/opa/blob/7ddaff2cc3dd749af25bab7d6a1f5a9cdbfe9833/v1/ast/term.go#L380
48+
fileprivate subscript(key: RegoValue) -> RegoValue? {
49+
switch self {
50+
case .object(let o):
51+
return o[key]
52+
case .array(let a):
53+
// key must be an integer position
54+
guard !key.isFloat, let index = key.integerValue else { return nil }
55+
// check bounds
56+
guard index >= 0 && index < a.count && index < Int32.max else { return nil }
57+
return a[Int(index)]
58+
case .set(let s):
59+
if s.contains(key) {
60+
return key
61+
}
62+
return nil
63+
default:
64+
return nil
65+
}
66+
}
2467
}

Sources/Rego/Builtins/Registry.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ public struct BuiltinRegistry: Sendable {
3838
"count": BuiltinFuncs.count,
3939
"max": BuiltinFuncs.max,
4040
"min": BuiltinFuncs.min,
41+
"product": BuiltinFuncs.product,
42+
"sort": BuiltinFuncs.sort,
43+
"sum": BuiltinFuncs.sum,
4144

4245
// Arithmetic
4346
"plus": BuiltinFuncs.plus,
@@ -65,6 +68,7 @@ public struct BuiltinRegistry: Sendable {
6568

6669
// Collections
6770
"internal.member_2": BuiltinFuncs.isMemberOf,
71+
"internal.member_3": BuiltinFuncs.isMemberOfWithKey,
6872

6973
// Comparison
7074
"gt": BuiltinFuncs.greaterThan,

Tests/RegoTests/BuiltinTests/AggregatesTests.swift

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,172 @@ extension BuiltinTests.AggregatesTests {
234234
),
235235
]
236236

237+
static let sumTests: [BuiltinTests.TestCase] = [
238+
BuiltinTests.TestCase(
239+
description: "empty array",
240+
name: "sum",
241+
args: [[]],
242+
expected: .success(0)
243+
),
244+
BuiltinTests.TestCase(
245+
description: "array",
246+
name: "sum",
247+
args: [[1, 2, 3.14, 4]],
248+
expected: .success(.number(NSDecimalNumber(decimal: 1 + 2 + 3.14 + 4)))
249+
),
250+
BuiltinTests.TestCase(
251+
description: "array of various objects",
252+
name: "sum",
253+
args: [[1, "a", "b"]],
254+
expected: .failure(
255+
BuiltinError.argumentTypeMismatch(
256+
arg: "collection", got: "array[any<number, string>]", want: "any<array[number], set[number]>"))
257+
),
258+
BuiltinTests.TestCase(
259+
description: "empty set",
260+
name: "sum",
261+
args: [.set([])],
262+
expected: .success(0)
263+
),
264+
BuiltinTests.TestCase(
265+
description: "set",
266+
name: "sum",
267+
args: [.set([1, 5, 8.65])],
268+
expected: .success(.number(NSDecimalNumber(decimal: 1 + 5 + 8.65)))
269+
),
270+
BuiltinTests.TestCase(
271+
description: "set of various objects",
272+
name: "sum",
273+
args: [.set([1, "a", "b"])],
274+
expected: .failure(
275+
BuiltinError.argumentTypeMismatch(
276+
arg: "collection", got: "set[any<number, string>]", want: "any<array[number], set[number]>"))
277+
),
278+
]
279+
280+
static let productTests: [BuiltinTests.TestCase] = [
281+
BuiltinTests.TestCase(
282+
description: "empty array",
283+
name: "product",
284+
args: [[]],
285+
expected: .success(1)
286+
),
287+
BuiltinTests.TestCase(
288+
description: "array",
289+
name: "product",
290+
args: [[1, 2, 3.14, 4]],
291+
expected: .success(.number(NSDecimalNumber(decimal: 1 * 2 * 3.14 * 4)))
292+
),
293+
BuiltinTests.TestCase(
294+
description: "array of various objects",
295+
name: "product",
296+
args: [[1, "a"]],
297+
expected: .failure(
298+
BuiltinError.argumentTypeMismatch(
299+
arg: "collection", got: "array<number, string>", want: "any<array[number], set[number]>"))
300+
),
301+
BuiltinTests.TestCase(
302+
description: "empty set",
303+
name: "product",
304+
args: [.set([])],
305+
expected: .success(1)
306+
),
307+
BuiltinTests.TestCase(
308+
description: "set",
309+
name: "product",
310+
args: [.set([1, 5, 8.65])],
311+
expected: .success(.number(NSDecimalNumber(decimal: 1 * 5 * 8.65)))
312+
),
313+
BuiltinTests.TestCase(
314+
description: "set of various objects",
315+
name: "product",
316+
args: [.set([1, "a"])],
317+
expected: .failure(
318+
BuiltinError.argumentTypeMismatch(
319+
arg: "collection", got: "set<number, string>", want: "any<array[number], set[number]>"))
320+
),
321+
]
322+
323+
static let sortTests: [BuiltinTests.TestCase] = [
324+
BuiltinTests.TestCase(
325+
description: "array",
326+
name: "sort",
327+
args: [[1, 100, 2]],
328+
expected: .success([1, 2, 100])
329+
),
330+
BuiltinTests.TestCase(
331+
description: "string array",
332+
name: "sort",
333+
args: [["b", "a"]],
334+
expected: .success(["a", "b"])
335+
),
336+
BuiltinTests.TestCase(
337+
description: "empty array",
338+
name: "sort",
339+
args: [[]],
340+
expected: .success([])
341+
),
342+
BuiltinTests.TestCase(
343+
description: "array of objects",
344+
name: "sort",
345+
args: [
346+
[["a": 1], ["a": 100], ["a": 3]]
347+
],
348+
expected: .success([["a": 1], ["a": 3], ["a": 100]])
349+
),
350+
BuiltinTests.TestCase(
351+
description: "array of objects with different keys",
352+
name: "sort",
353+
args: [
354+
[["a": 100], ["c": 3, "d": 4], ["b": 101]]
355+
],
356+
expected: .success([["a": 100], ["b": 101], ["c": 3, "d": 4]])
357+
),
358+
BuiltinTests.TestCase(
359+
description: "set",
360+
name: "sort",
361+
args: [.set([1, 100, 2])],
362+
expected: .success([1, 2, 100])
363+
),
364+
BuiltinTests.TestCase(
365+
description: "string set",
366+
name: "sort",
367+
args: [.set(["b", "a"])],
368+
expected: .success(["a", "b"])
369+
),
370+
BuiltinTests.TestCase(
371+
description: "empty set",
372+
name: "sort",
373+
args: [.set([])],
374+
expected: .success([])
375+
),
376+
BuiltinTests.TestCase(
377+
description: "set of objects",
378+
name: "sort",
379+
args: [
380+
.set([["a": 1], ["a": 100], ["a": 3]])
381+
],
382+
expected: .success([["a": 1], ["a": 3], ["a": 100]])
383+
),
384+
BuiltinTests.TestCase(
385+
description: "set of objects with different keys",
386+
name: "sort",
387+
args: [
388+
.set([["a": 100], ["c": 3, "d": 4], ["b": 101]])
389+
],
390+
// 2nd element has largest key
391+
expected: .success([["a": 100], ["b": 101], ["c": 3, "d": 4]])
392+
),
393+
BuiltinTests.TestCase(
394+
description: "array of different types",
395+
name: "sort",
396+
args: [
397+
[[1, 100, 0], .object(["z": 999]), .set([0]), [999], "10000"]
398+
],
399+
expected: .success(["10000", [1, 100, 0], [999], .object(["z": 999]), .set([0])])
400+
),
401+
]
402+
237403
static var allTests: [BuiltinTests.TestCase] {
238404
[
239405
BuiltinTests.generateFailureTests(
@@ -256,6 +422,28 @@ extension BuiltinTests.AggregatesTests {
256422
allowedArgTypes: ["array", "set"],
257423
generateNumberOfArgsTest: true),
258424
minTests,
425+
426+
BuiltinTests.generateFailureTests(
427+
builtinName: "sum", sampleArgs: [[]],
428+
argIndex: 0, argName: "collection",
429+
allowedArgTypes: ["array", "set"],
430+
generateNumberOfArgsTest: true),
431+
sumTests,
432+
433+
BuiltinTests.generateFailureTests(
434+
builtinName: "product", sampleArgs: [[]],
435+
argIndex: 0, argName: "collection",
436+
allowedArgTypes: ["array", "set"],
437+
generateNumberOfArgsTest: true),
438+
productTests,
439+
440+
BuiltinTests.generateFailureTests(
441+
builtinName: "sort", sampleArgs: [[]],
442+
argIndex: 0, argName: "collection",
443+
allowedArgTypes: ["array", "set"],
444+
generateNumberOfArgsTest: true),
445+
sortTests,
446+
259447
].flatMap { $0 }
260448
}
261449

0 commit comments

Comments
 (0)