Skip to content

Commit 8aa9a97

Browse files
Refactor prompt block into base (#3018)
* Refactor prompt block into base * Fix TypeScript errors in base-prompt-block.ts Co-Authored-By: [email protected] <[email protected]> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent e1f3967 commit 8aa9a97

File tree

3 files changed

+201
-343
lines changed

3 files changed

+201
-343
lines changed

ee/codegen/src/generators/base-prompt-block.ts

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import {
1111
VideoPromptBlock,
1212
} from "vellum-ai/api";
1313

14+
import { Json } from "./json";
15+
1416
import { VELLUM_CLIENT_MODULE_PATH } from "src/constants";
1517
import { WorkflowContext } from "src/context/workflow-context";
1618
import { AstNode } from "src/generators/extensions/ast-node";
@@ -68,7 +70,195 @@ export abstract class BasePromptBlock<
6870
this.astNode = this.generateAstNode(promptBlock);
6971
}
7072

71-
protected abstract generateAstNode(promptBlock: T): ClassInstantiation;
73+
protected generateAstNode(promptBlock: T): ClassInstantiation {
74+
switch (promptBlock.blockType) {
75+
case "JINJA":
76+
return this.generateJinjaPromptBlock(
77+
promptBlock as Extract<T, { blockType: "JINJA" }>
78+
);
79+
case "CHAT_MESSAGE":
80+
return this.generateChatMessagePromptBlock(
81+
promptBlock as Extract<T, { blockType: "CHAT_MESSAGE" }>
82+
);
83+
case "VARIABLE":
84+
return this.generateVariablePromptBlock(
85+
promptBlock as Extract<T, { blockType: "VARIABLE" }>
86+
);
87+
case "RICH_TEXT":
88+
return this.generateRichTextPromptBlock(
89+
promptBlock as Extract<T, { blockType: "RICH_TEXT" }>
90+
);
91+
case "PLAIN_TEXT":
92+
return this.generatePlainTextPromptBlock(
93+
promptBlock as Extract<T, { blockType: "PLAIN_TEXT" }>
94+
);
95+
case "AUDIO":
96+
return this.generateAudioPromptBlock(
97+
promptBlock as Extract<T, { blockType: "AUDIO" }>
98+
);
99+
case "VIDEO":
100+
return this.generateVideoPromptBlock(
101+
promptBlock as Extract<T, { blockType: "VIDEO" }>
102+
);
103+
case "IMAGE":
104+
return this.generateImagePromptBlock(
105+
promptBlock as Extract<T, { blockType: "IMAGE" }>
106+
);
107+
case "DOCUMENT":
108+
return this.generateDocumentPromptBlock(
109+
promptBlock as Extract<T, { blockType: "DOCUMENT" }>
110+
);
111+
}
112+
}
113+
114+
protected abstract generateJinjaPromptBlock(
115+
promptBlock: Extract<T, { blockType: "JINJA" }>
116+
): ClassInstantiation;
117+
protected abstract generateChatMessagePromptBlock(
118+
promptBlock: Extract<T, { blockType: "CHAT_MESSAGE" }>
119+
): ClassInstantiation;
120+
protected abstract generateVariablePromptBlock(
121+
promptBlock: Extract<T, { blockType: "VARIABLE" }>
122+
): ClassInstantiation;
123+
protected abstract generateRichTextPromptBlock(
124+
promptBlock: Extract<T, { blockType: "RICH_TEXT" }>
125+
): ClassInstantiation;
126+
protected abstract generatePlainTextPromptBlock(
127+
promptBlock: Extract<T, { blockType: "PLAIN_TEXT" }>
128+
): ClassInstantiation;
129+
protected generateAudioPromptBlock(
130+
promptBlock: Extract<T, { blockType: "AUDIO" }>
131+
): ClassInstantiation {
132+
const classArgs: MethodArgument[] = [
133+
...this.constructCommonClassArguments(promptBlock),
134+
...this.generateCommonFileInputArguments(promptBlock),
135+
];
136+
137+
const audioBlock = new ClassInstantiation({
138+
classReference: this.getPromptBlockRef(promptBlock),
139+
arguments_: classArgs,
140+
});
141+
142+
this.inheritReferences(audioBlock);
143+
return audioBlock;
144+
}
145+
146+
protected generateVideoPromptBlock(
147+
promptBlock: Extract<T, { blockType: "VIDEO" }>
148+
): ClassInstantiation {
149+
const classArgs: MethodArgument[] = [
150+
...this.constructCommonClassArguments(promptBlock),
151+
...this.generateCommonFileInputArguments(promptBlock),
152+
];
153+
154+
const videoBlock = new ClassInstantiation({
155+
classReference: this.getPromptBlockRef(promptBlock),
156+
arguments_: classArgs,
157+
});
158+
159+
this.inheritReferences(videoBlock);
160+
return videoBlock;
161+
}
162+
163+
protected generateImagePromptBlock(
164+
promptBlock: Extract<T, { blockType: "IMAGE" }>
165+
): ClassInstantiation {
166+
const classArgs: MethodArgument[] = [
167+
...this.constructCommonClassArguments(promptBlock),
168+
...this.generateCommonFileInputArguments(promptBlock),
169+
];
170+
171+
const imageBlock = new ClassInstantiation({
172+
classReference: this.getPromptBlockRef(promptBlock),
173+
arguments_: classArgs,
174+
});
175+
176+
this.inheritReferences(imageBlock);
177+
return imageBlock;
178+
}
179+
180+
protected generateDocumentPromptBlock(
181+
promptBlock: Extract<T, { blockType: "DOCUMENT" }>
182+
): ClassInstantiation {
183+
const classArgs: MethodArgument[] = [
184+
...this.constructCommonClassArguments(promptBlock),
185+
...this.generateCommonFileInputArguments(promptBlock),
186+
];
187+
188+
const documentBlock = new ClassInstantiation({
189+
classReference: this.getPromptBlockRef(promptBlock),
190+
arguments_: classArgs,
191+
});
192+
193+
this.inheritReferences(documentBlock);
194+
return documentBlock;
195+
}
196+
197+
protected getPromptBlockRef(promptBlock: T): Reference {
198+
let pathName;
199+
switch (promptBlock.blockType) {
200+
case "JINJA":
201+
pathName = "JinjaPromptBlock";
202+
break;
203+
case "CHAT_MESSAGE":
204+
pathName = "ChatMessagePromptBlock";
205+
break;
206+
case "VARIABLE":
207+
pathName = "VariablePromptBlock";
208+
break;
209+
case "RICH_TEXT":
210+
pathName = "RichTextPromptBlock";
211+
break;
212+
case "PLAIN_TEXT":
213+
pathName = "PlainTextPromptBlock";
214+
break;
215+
case "AUDIO":
216+
pathName = "AudioPromptBlock";
217+
break;
218+
case "VIDEO":
219+
pathName = "VideoPromptBlock";
220+
break;
221+
case "IMAGE":
222+
pathName = "ImagePromptBlock";
223+
break;
224+
case "DOCUMENT":
225+
pathName = "DocumentPromptBlock";
226+
break;
227+
}
228+
return new Reference({
229+
name: pathName,
230+
modulePath: VELLUM_CLIENT_MODULE_PATH,
231+
});
232+
}
233+
234+
protected generateCommonFileInputArguments(
235+
promptBlock:
236+
| Extract<T, { blockType: "AUDIO" }>
237+
| Extract<T, { blockType: "VIDEO" }>
238+
| Extract<T, { blockType: "IMAGE" }>
239+
| Extract<T, { blockType: "DOCUMENT" }>
240+
): MethodArgument[] {
241+
const classArgs: MethodArgument[] = [];
242+
243+
classArgs.push(
244+
new MethodArgument({
245+
name: "src",
246+
value: new StrInstantiation(promptBlock.src),
247+
})
248+
);
249+
250+
if (promptBlock.metadata) {
251+
const metadataJson = new Json(promptBlock.metadata);
252+
classArgs.push(
253+
new MethodArgument({
254+
name: "metadata",
255+
value: metadataJson,
256+
})
257+
);
258+
}
259+
260+
return classArgs;
261+
}
72262

73263
protected constructCommonClassArguments(promptBlock: T): MethodArgument[] {
74264
const args: MethodArgument[] = [];

0 commit comments

Comments
 (0)