diff --git a/test/cases/nvidia-inference/bert_inference_test.go b/test/cases/nvidia-inference/bert_inference_test.go index 5f7ef251f..8e0844a99 100644 --- a/test/cases/nvidia-inference/bert_inference_test.go +++ b/test/cases/nvidia-inference/bert_inference_test.go @@ -38,7 +38,7 @@ func TestBertInference(t *testing.T) { WithLabel("suite", "nvidia"). WithLabel("hardware", "gpu"). Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context { - if *bertInferenceImage == "" { + if testConfig.BertInferenceImage == "" { t.Fatalf("[ERROR] bertInferenceImage must be set") } @@ -47,9 +47,9 @@ func TestBertInference(t *testing.T) { renderedBertInferenceManifest, err = fwext.RenderManifests( bertInferenceManifest, bertInferenceManifestTplVars{ - BertInferenceImage: *bertInferenceImage, - InferenceMode: *inferenceMode, - GPUPerNode: fmt.Sprintf("%d", *gpuRequested), + BertInferenceImage: testConfig.BertInferenceImage, + InferenceMode: testConfig.InferenceMode, + GPUPerNode: fmt.Sprintf("%d", testConfig.GpuRequested), }, ) if err != nil { diff --git a/test/cases/nvidia-inference/main_test.go b/test/cases/nvidia-inference/main_test.go index fcb19cc51..fd10fabdb 100644 --- a/test/cases/nvidia-inference/main_test.go +++ b/test/cases/nvidia-inference/main_test.go @@ -5,17 +5,17 @@ package inference import ( "context" _ "embed" - "flag" "fmt" "log" "os" + "os/signal" "slices" "testing" "time" fwext "github.com/aws/aws-k8s-tester/internal/e2e" + "github.com/aws/aws-k8s-tester/test/common" "github.com/aws/aws-k8s-tester/test/manifests" - appsv1 "k8s.io/api/apps/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "sigs.k8s.io/e2e-framework/klient/wait" @@ -23,51 +23,71 @@ import ( "sigs.k8s.io/e2e-framework/pkg/envconf" ) +type TestConfig struct { + common.MetricOps + BertInferenceImage string `flag:"bertInferenceImage" desc:"BERT inference container image"` + InferenceMode string `flag:"inferenceMode" desc:"Inference mode for BERT (throughput or latency)"` + GpuRequested int `flag:"gpuRequested" desc:"Number of GPUs required for inference"` +} + var ( - testenv env.Environment - bertInferenceImage *string - inferenceMode *string - gpuRequested *int + testenv env.Environment + testConfig TestConfig ) func TestMain(m *testing.M) { - bertInferenceImage = flag.String("bertInferenceImage", "", "BERT inference container image") - inferenceMode = flag.String("inferenceMode", "throughput", "Inference mode for BERT (throughput or latency)") - gpuRequested = flag.Int("gpuRequested", 1, "Number of GPUs required for inference") + // Initialize testConfig with default values + testConfig = TestConfig{ + InferenceMode: "throughput", + GpuRequested: 1, + } + _, err := common.ParseFlags(&testConfig) + if err != nil { + log.Fatalf("[ERROR] Failed to parse flags: %v", err) + } cfg, err := envconf.NewFromFlags() if err != nil { - log.Fatalf("[ERROR] Failed to create test environment: %v", err) + log.Fatalf("[ERROR] Failed to initialize test environment: %v", err) } - testenv = env.NewWithConfig(cfg) - devicePluginManifests := [][]byte{ + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + testenv = env.NewWithConfig(cfg).WithContext(ctx) + + manifestsList := [][]byte{ manifests.NvidiaDevicePluginManifest, } + if len(testConfig.MetricDimensions) > 0 { + // Render CloudWatch Agent manifest with dynamic dimensions + renderedCloudWatchAgentManifest, err := manifests.RenderCloudWatchAgentManifest(testConfig.MetricDimensions) + if err != nil { + log.Printf("Warning: Failed to render CloudWatch Agent manifest: %v", err) + } + manifestsList = append(manifestsList, manifests.DCGMExporterManifest, renderedCloudWatchAgentManifest) + } + testenv.Setup( func(ctx context.Context, config *envconf.Config) (context.Context, error) { - log.Println("[INFO] Applying NVIDIA device plugin.") - if applyErr := fwext.ApplyManifests(config.Client().RESTConfig(), devicePluginManifests...); applyErr != nil { - return ctx, fmt.Errorf("failed to apply device plugin: %w", applyErr) + log.Println("[INFO] Applying manifests.") + err := fwext.ApplyManifests(config.Client().RESTConfig(), manifestsList...) + if err != nil { + return ctx, fmt.Errorf("[ERROR] Failed to apply manifests: %w", err) } + log.Println("[INFO] Successfully applied manifests.") return ctx, nil }, + common.DeployDaemonSet("nvidia-device-plugin-daemonset", "kube-system"), func(ctx context.Context, config *envconf.Config) (context.Context, error) { - ds := &appsv1.DaemonSet{ - ObjectMeta: metav1.ObjectMeta{ - Name: "nvidia-device-plugin-daemonset", - Namespace: "kube-system", - }, + if len(testConfig.MetricDimensions) > 0 { + if ctx, err := common.DeployDaemonSet("dcgm-exporter", "kube-system")(ctx, config); err != nil { + return ctx, err + } + if ctx, err := common.DeployDaemonSet("cwagent", "amazon-cloudwatch")(ctx, config); err != nil { + return ctx, err + } } - err := wait.For( - fwext.NewConditionExtension(config.Client().Resources()).DaemonSetReady(ds), - wait.WithTimeout(5*time.Minute), - ) - if err != nil { - return ctx, fmt.Errorf("device plugin daemonset not ready: %w", err) - } - log.Println("[INFO] NVIDIA device plugin is ready.") return ctx, nil }, checkGpuCapacity, @@ -75,25 +95,26 @@ func TestMain(m *testing.M) { testenv.Finish( func(ctx context.Context, config *envconf.Config) (context.Context, error) { - log.Println("[INFO] Cleaning up NVIDIA device plugin.") - slices.Reverse(devicePluginManifests) - if delErr := fwext.DeleteManifests(config.Client().RESTConfig(), devicePluginManifests...); delErr != nil { - return ctx, fmt.Errorf("failed to delete device plugin: %w", delErr) + log.Println("[INFO] Deleting manifests.") + slices.Reverse(manifestsList) + err := fwext.DeleteManifests(config.Client().RESTConfig(), manifestsList...) + if err != nil { + return ctx, fmt.Errorf("[ERROR] failed to delete manifests: %w", err) } - log.Println("[INFO] Device plugin cleanup complete.") + log.Println("[INFO] Successfully deleted manifests.") return ctx, nil }, ) exitCode := testenv.Run(m) - log.Printf("[INFO] Test environment finished with exit code %d", exitCode) + log.Printf("[INFO] Tests finished with exit code %d", exitCode) os.Exit(exitCode) } // checkGpuCapacity ensures at least one node has >= the requested number of GPUs, // and logs each node's instance type. func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) { - log.Printf("[INFO] Validating cluster has at least %d GPU(s).", *gpuRequested) + log.Printf("[INFO] Validating cluster has at least %d GPU(s).", testConfig.GpuRequested) cs, err := kubernetes.NewForConfig(config.Client().RESTConfig()) if err != nil { @@ -110,9 +131,9 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont for _, node := range nodes.Items { instanceType := node.Labels["node.kubernetes.io/instance-type"] gpuCap, ok := node.Status.Capacity["nvidia.com/gpu"] - if ok && int(gpuCap.Value()) >= *gpuRequested { + if ok && int(gpuCap.Value()) >= testConfig.GpuRequested { log.Printf("[INFO] Node %s (type: %s) meets the request of %d GPU(s).", - node.Name, instanceType, *gpuRequested) + node.Name, instanceType, testConfig.GpuRequested) return true, nil } log.Printf("[INFO] Node %s (type: %s) has no GPU capacity.", node.Name, instanceType) @@ -122,7 +143,7 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont }, wait.WithTimeout(5*time.Minute), wait.WithInterval(10*time.Second)) if err != nil { - return ctx, fmt.Errorf("no node has >= %d GPU(s)", *gpuRequested) + return ctx, fmt.Errorf("no node has >= %d GPU(s)", testConfig.GpuRequested) } log.Println("[INFO] GPU capacity check passed.")