Skip to content

Commit 3d896e9

Browse files
author
David Motsonashvili
committed
Add JsonSchema types, typeconversions, and tests
1 parent 28efb77 commit 3d896e9

File tree

9 files changed

+190
-18
lines changed

9 files changed

+190
-18
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ import kotlinx.coroutines.flow.map
6767
import kotlinx.coroutines.launch
6868
import kotlinx.coroutines.withTimeout
6969
import kotlinx.serialization.ExperimentalSerializationApi
70+
import kotlinx.serialization.json.ClassDiscriminatorMode
7071
import kotlinx.serialization.json.Json
7172

7273
@OptIn(ExperimentalSerializationApi::class)
@@ -75,6 +76,7 @@ internal val JSON = Json {
7576
prettyPrint = false
7677
isLenient = true
7778
explicitNulls = false
79+
classDiscriminatorMode = ClassDiscriminatorMode.NONE
7880
}
7981

8082
/**

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
8585
}
8686

8787
@OptIn(ExperimentalSerializationApi::class)
88-
internal fun toInternal() = Internal(this.role ?: "user", this.parts.map { it.toInternal() })
88+
internal fun toInternal() =
89+
Internal(this.role ?: "user", this.parts.map { it.toInternalOpenApi() })
8990

9091
@ExperimentalSerializationApi
9192
@Serializable

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public class FunctionDeclaration(
6161
internal val schema: Schema =
6262
Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false)
6363

64-
internal fun toInternal() = Internal(name, description, schema.toInternal())
64+
internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi())
6565

6666
@Serializable
6767
internal data class Internal(

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ private constructor(
200200
frequencyPenalty = frequencyPenalty,
201201
presencePenalty = presencePenalty,
202202
responseMimeType = responseMimeType,
203-
responseSchema = responseSchema?.toInternal(),
203+
responseSchema = responseSchema?.toInternalOpenApi(),
204204
responseModalities = responseModalities?.map { it.toInternal() },
205205
thinkingConfig = thinkingConfig?.toInternal()
206206
)

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ internal object PartSerializer :
329329
}
330330
}
331331

332-
internal fun Part.toInternal(): InternalPart {
332+
internal fun Part.toInternalOpenApi(): InternalPart {
333333
return when (this) {
334334
is TextPart -> TextPart.Internal(text, isThought, thoughtSignature)
335335
is ImagePart ->

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -322,46 +322,149 @@ internal constructor(
322322
public fun anyOf(schemas: List<Schema>): Schema = Schema(type = "ANYOF", anyOf = schemas)
323323
}
324324

325-
internal fun toInternal(): Internal {
325+
internal fun toInternalOpenApi(): InternalOpenAPI {
326326
val cleanedType =
327327
if (type == "ANYOF") {
328328
null
329329
} else {
330330
type
331331
}
332-
return Internal(
332+
return InternalOpenAPI(
333333
cleanedType,
334334
description,
335335
format,
336336
nullable,
337337
enum,
338-
properties?.mapValues { it.value.toInternal() },
338+
properties?.mapValues { it.value.toInternalOpenApi() },
339339
required,
340-
items?.toInternal(),
340+
items?.toInternalOpenApi(),
341341
title,
342342
minItems,
343343
maxItems,
344344
minimum,
345345
maximum,
346-
anyOf?.map { it.toInternal() },
346+
anyOf?.map { it.toInternalOpenApi() },
347347
)
348348
}
349349

350+
internal fun toInternalJson(): InternalJson {
351+
val outType =
352+
if (type == "ANYOF" || (type == "STRING" && format == "enum")) {
353+
null
354+
} else {
355+
type.lowercase()
356+
}
357+
358+
val (outMinimum, outMaximum) =
359+
if (outType == "integer" && format == "int32") {
360+
(minimum ?: Integer.MIN_VALUE.toDouble()) to (maximum ?: Integer.MAX_VALUE.toDouble())
361+
} else {
362+
minimum to maximum
363+
}
364+
365+
val outFormat =
366+
if (
367+
(outType == "integer" && format == "int32") ||
368+
(outType == "number" && format == "float") ||
369+
format == "enum"
370+
) {
371+
null
372+
} else {
373+
format
374+
}
375+
376+
if (nullable == true) {
377+
return InternalJsonNullable(
378+
outType?.let { listOf(it, "null") },
379+
description,
380+
outFormat,
381+
enum?.let {
382+
buildList {
383+
addAll(it)
384+
add("null")
385+
}
386+
},
387+
properties?.mapValues { it.value.toInternalJson() },
388+
required,
389+
items?.toInternalJson(),
390+
title,
391+
minItems,
392+
maxItems,
393+
outMinimum,
394+
outMaximum,
395+
anyOf?.map { it.toInternalJson() },
396+
)
397+
}
398+
return InternalJsonNonNull(
399+
outType,
400+
description,
401+
outFormat,
402+
enum,
403+
properties?.mapValues { it.value.toInternalJson() },
404+
required,
405+
items?.toInternalJson(),
406+
title,
407+
minItems,
408+
maxItems,
409+
outMinimum,
410+
outMaximum,
411+
anyOf?.map { it.toInternalJson() },
412+
)
413+
}
414+
415+
@Serializable internal sealed interface Internal
416+
350417
@Serializable
351-
internal data class Internal(
418+
internal data class InternalOpenAPI(
352419
val type: String? = null,
353420
val description: String? = null,
354421
val format: String? = null,
355422
val nullable: Boolean? = false,
356423
val enum: List<String>? = null,
357-
val properties: Map<String, Internal>? = null,
424+
val properties: Map<String, InternalOpenAPI>? = null,
425+
val required: List<String>? = null,
426+
val items: InternalOpenAPI? = null,
427+
val title: String? = null,
428+
val minItems: Int? = null,
429+
val maxItems: Int? = null,
430+
val minimum: Double? = null,
431+
val maximum: Double? = null,
432+
val anyOf: List<InternalOpenAPI>? = null,
433+
) : Internal
434+
435+
@Serializable internal sealed interface InternalJson : Internal
436+
437+
@Serializable
438+
internal data class InternalJsonNonNull(
439+
val type: String? = null,
440+
val description: String? = null,
441+
val format: String? = null,
442+
val enum: List<String>? = null,
443+
val properties: Map<String, InternalJson>? = null,
444+
val required: List<String>? = null,
445+
val items: InternalJson? = null,
446+
val title: String? = null,
447+
val minItems: Int? = null,
448+
val maxItems: Int? = null,
449+
val minimum: Double? = null,
450+
val maximum: Double? = null,
451+
val anyOf: List<InternalJson>? = null,
452+
) : InternalJson
453+
454+
@Serializable
455+
internal data class InternalJsonNullable(
456+
val type: List<String>? = null,
457+
val description: String? = null,
458+
val format: String? = null,
459+
val enum: List<String>? = null,
460+
val properties: Map<String, InternalJson>? = null,
358461
val required: List<String>? = null,
359-
val items: Internal? = null,
462+
val items: InternalJson? = null,
360463
val title: String? = null,
361464
val minItems: Int? = null,
362465
val maxItems: Int? = null,
363466
val minimum: Double? = null,
364467
val maximum: Double? = null,
365-
val anyOf: List<Internal>? = null,
366-
)
468+
val anyOf: List<InternalJson>? = null,
469+
) : InternalJson
367470
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Type.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ internal data class GRpcErrorResponse(val error: GRpcError) : Response {
4242
}
4343
}
4444

45-
internal fun JSONObject.toInternal() = Json.decodeFromString<JsonObject>(toString())
45+
internal fun JSONObject.toInternalOpenApi() = Json.decodeFromString<JsonObject>(toString())
4646

4747
internal fun JsonObject.toPublic() = JSONObject(toString())

firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package com.google.firebase.ai
1919
import com.google.firebase.ai.type.Schema
2020
import com.google.firebase.ai.type.StringFormat
2121
import io.kotest.assertions.json.shouldEqualJson
22+
import java.io.File
2223
import kotlinx.serialization.encodeToString
24+
import kotlinx.serialization.json.ClassDiscriminatorMode
2325
import kotlinx.serialization.json.Json
2426
import org.junit.Test
2527

@@ -93,7 +95,7 @@ internal class SchemaTests {
9395
"""
9496
.trimIndent()
9597

96-
Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
98+
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
9799
}
98100

99101
@Test
@@ -216,6 +218,70 @@ internal class SchemaTests {
216218
"""
217219
.trimIndent()
218220

219-
Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
221+
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
220222
}
223+
224+
@Test
225+
fun `schema encoding openAPI spec test`() {
226+
val expectedSerialization = getSchemaJson("open-api-schema.json")
227+
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalOpenApi())
228+
serializedSchema.shouldEqualJson(expectedSerialization)
229+
}
230+
231+
@Test
232+
fun `schema encoding jsonSchema spec test`() {
233+
val expectedSerialization = getSchemaJson("json-schema.json")
234+
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalJson())
235+
serializedSchema.shouldEqualJson(expectedSerialization)
236+
}
237+
238+
internal fun getSchemaJson(filename: String): String {
239+
return File("src/test/resources/vertexai-sdk-test-data/mock-responses/schema/${filename}")
240+
.readText()
241+
}
242+
243+
private val JSON_ENCODER = Json { classDiscriminatorMode = ClassDiscriminatorMode.NONE }
244+
245+
private val TEST_SCHEMA =
246+
Schema.obj(
247+
properties =
248+
mapOf(
249+
"integerTest" to Schema.integer(title = "integerTest", nullable = true),
250+
"longTest" to
251+
Schema.long(
252+
title = "longTest",
253+
nullable = false,
254+
minimum = 0.0,
255+
maximum = 5.0,
256+
description = "a test long"
257+
),
258+
"floatTest" to Schema.float(title = "floatTest", nullable = false),
259+
"doubleTest" to Schema.double(title = "doubleTest", nullable = true),
260+
"listTest" to
261+
Schema.array(
262+
items = Schema.integer(nullable = false),
263+
title = "listTest",
264+
nullable = false,
265+
minItems = 0,
266+
maxItems = 5
267+
),
268+
"booleanTest" to Schema.boolean(title = "booleanTest", nullable = false),
269+
"stringTest" to
270+
Schema.string(title = "stringTest", format = StringFormat.Custom("email")),
271+
"objTest" to
272+
Schema.obj(
273+
properties =
274+
mapOf(
275+
"testInt" to Schema.integer(title = "testInt", nullable = false),
276+
),
277+
title = "objTest",
278+
description = "class kdoc should be used if property kdocs aren't present",
279+
nullable = false
280+
),
281+
"enumTest" to Schema.enumeration(values = listOf("val1", "val2", "val3"))
282+
),
283+
optionalProperties = listOf("booleanTest"),
284+
description = "A test kdoc",
285+
nullable = false
286+
)
221287
}

firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ internal class SerializationTests {
437437
}
438438
"""
439439
.trimIndent()
440-
val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor)
440+
val actualJson = descriptorToJson(Schema.InternalOpenAPI.serializer().descriptor)
441441
expectedJsonAsString shouldEqualJson actualJson.toString()
442442
}
443443

0 commit comments

Comments
 (0)