Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/cases/nvidia-inference/bert_inference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestBertInference(t *testing.T) {
WithLabel("suite", "nvidia").
WithLabel("hardware", "gpu").
Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
if *bertInferenceImage == "" {
if testConfig.BertInferenceImage == "" {
t.Fatalf("[ERROR] bertInferenceImage must be set")
}

Expand All @@ -47,9 +47,9 @@ func TestBertInference(t *testing.T) {
renderedBertInferenceManifest, err = fwext.RenderManifests(
bertInferenceManifest,
bertInferenceManifestTplVars{
BertInferenceImage: *bertInferenceImage,
InferenceMode: *inferenceMode,
GPUPerNode: fmt.Sprintf("%d", *gpuRequested),
BertInferenceImage: testConfig.BertInferenceImage,
InferenceMode: testConfig.InferenceMode,
GPUPerNode: fmt.Sprintf("%d", testConfig.GpuRequested),
},
)
if err != nil {
Expand Down
97 changes: 59 additions & 38 deletions test/cases/nvidia-inference/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,95 +5,116 @@ package inference
import (
"context"
_ "embed"
"flag"
"fmt"
"log"
"os"
"os/signal"
"slices"
"testing"
"time"

fwext "github.com/aws/aws-k8s-tester/internal/e2e"
"github.com/aws/aws-k8s-tester/test/common"
"github.com/aws/aws-k8s-tester/test/manifests"
appsv1 "k8s.io/api/apps/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"sigs.k8s.io/e2e-framework/klient/wait"
"sigs.k8s.io/e2e-framework/pkg/env"
"sigs.k8s.io/e2e-framework/pkg/envconf"
)

type TestConfig struct {
common.MetricOps
BertInferenceImage string `flag:"bertInferenceImage" desc:"BERT inference container image"`
InferenceMode string `flag:"inferenceMode" desc:"Inference mode for BERT (throughput or latency)"`
GpuRequested int `flag:"gpuRequested" desc:"Number of GPUs required for inference"`
}

var (
testenv env.Environment
bertInferenceImage *string
inferenceMode *string
gpuRequested *int
testenv env.Environment
testConfig TestConfig
)

func TestMain(m *testing.M) {
bertInferenceImage = flag.String("bertInferenceImage", "", "BERT inference container image")
inferenceMode = flag.String("inferenceMode", "throughput", "Inference mode for BERT (throughput or latency)")
gpuRequested = flag.Int("gpuRequested", 1, "Number of GPUs required for inference")
// Initialize testConfig with default values
testConfig = TestConfig{
InferenceMode: "throughput",
GpuRequested: 1,
}
Comment on lines +40 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Could you comment above about this is config default value?
  • We should maintain all config default value here including InferenceMode to keep the consistency


_, err := common.ParseFlags(&testConfig)
if err != nil {
log.Fatalf("[ERROR] Failed to parse flags: %v", err)
}
cfg, err := envconf.NewFromFlags()
if err != nil {
log.Fatalf("[ERROR] Failed to create test environment: %v", err)
log.Fatalf("[ERROR] Failed to initialize test environment: %v", err)
}
testenv = env.NewWithConfig(cfg)

devicePluginManifests := [][]byte{
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
testenv = env.NewWithConfig(cfg).WithContext(ctx)

manifestsList := [][]byte{
manifests.NvidiaDevicePluginManifest,
}

if len(testConfig.MetricDimensions) > 0 {
// Render CloudWatch Agent manifest with dynamic dimensions
renderedCloudWatchAgentManifest, err := manifests.RenderCloudWatchAgentManifest(testConfig.MetricDimensions)
if err != nil {
log.Printf("Warning: Failed to render CloudWatch Agent manifest: %v", err)
}
manifestsList = append(manifestsList, manifests.DCGMExporterManifest, renderedCloudWatchAgentManifest)
}

testenv.Setup(
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Println("[INFO] Applying NVIDIA device plugin.")
if applyErr := fwext.ApplyManifests(config.Client().RESTConfig(), devicePluginManifests...); applyErr != nil {
return ctx, fmt.Errorf("failed to apply device plugin: %w", applyErr)
log.Println("[INFO] Applying manifests.")
err := fwext.ApplyManifests(config.Client().RESTConfig(), manifestsList...)
if err != nil {
return ctx, fmt.Errorf("[ERROR] Failed to apply manifests: %w", err)
}
log.Println("[INFO] Successfully applied manifests.")
return ctx, nil
},
common.DeployDaemonSet("nvidia-device-plugin-daemonset", "kube-system"),
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
ds := &appsv1.DaemonSet{
ObjectMeta: metav1.ObjectMeta{
Name: "nvidia-device-plugin-daemonset",
Namespace: "kube-system",
},
if len(testConfig.MetricDimensions) > 0 {
if ctx, err := common.DeployDaemonSet("dcgm-exporter", "kube-system")(ctx, config); err != nil {
return ctx, err
}
if ctx, err := common.DeployDaemonSet("cwagent", "amazon-cloudwatch")(ctx, config); err != nil {
return ctx, err
}
}
err := wait.For(
fwext.NewConditionExtension(config.Client().Resources()).DaemonSetReady(ds),
wait.WithTimeout(5*time.Minute),
)
if err != nil {
return ctx, fmt.Errorf("device plugin daemonset not ready: %w", err)
}
log.Println("[INFO] NVIDIA device plugin is ready.")
return ctx, nil
},
checkGpuCapacity,
)

testenv.Finish(
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Println("[INFO] Cleaning up NVIDIA device plugin.")
slices.Reverse(devicePluginManifests)
if delErr := fwext.DeleteManifests(config.Client().RESTConfig(), devicePluginManifests...); delErr != nil {
return ctx, fmt.Errorf("failed to delete device plugin: %w", delErr)
log.Println("[INFO] Deleting manifests.")
slices.Reverse(manifestsList)
err := fwext.DeleteManifests(config.Client().RESTConfig(), manifestsList...)
if err != nil {
return ctx, fmt.Errorf("[ERROR] failed to delete manifests: %w", err)
}
log.Println("[INFO] Device plugin cleanup complete.")
log.Println("[INFO] Successfully deleted manifests.")
return ctx, nil
},
)

exitCode := testenv.Run(m)
log.Printf("[INFO] Test environment finished with exit code %d", exitCode)
log.Printf("[INFO] Tests finished with exit code %d", exitCode)
os.Exit(exitCode)
}

// checkGpuCapacity ensures at least one node has >= the requested number of GPUs,
// and logs each node's instance type.
func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Printf("[INFO] Validating cluster has at least %d GPU(s).", *gpuRequested)
log.Printf("[INFO] Validating cluster has at least %d GPU(s).", testConfig.GpuRequested)

cs, err := kubernetes.NewForConfig(config.Client().RESTConfig())
if err != nil {
Expand All @@ -110,9 +131,9 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont
for _, node := range nodes.Items {
instanceType := node.Labels["node.kubernetes.io/instance-type"]
gpuCap, ok := node.Status.Capacity["nvidia.com/gpu"]
if ok && int(gpuCap.Value()) >= *gpuRequested {
if ok && int(gpuCap.Value()) >= testConfig.GpuRequested {
log.Printf("[INFO] Node %s (type: %s) meets the request of %d GPU(s).",
node.Name, instanceType, *gpuRequested)
node.Name, instanceType, testConfig.GpuRequested)
return true, nil
}
log.Printf("[INFO] Node %s (type: %s) has no GPU capacity.", node.Name, instanceType)
Expand All @@ -122,7 +143,7 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont
}, wait.WithTimeout(5*time.Minute), wait.WithInterval(10*time.Second))

if err != nil {
return ctx, fmt.Errorf("no node has >= %d GPU(s)", *gpuRequested)
return ctx, fmt.Errorf("no node has >= %d GPU(s)", testConfig.GpuRequested)
}

log.Println("[INFO] GPU capacity check passed.")
Expand Down