|
5 | 5 | import onnx |
6 | 6 | import torch |
7 | 7 |
|
8 | | -import onnxruntime.training.onnxblock as onnxblock |
| 8 | +from onnxruntime.training import artifacts |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class MNIST(torch.nn.Module): |
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix): |
96 | 96 | 4. The checkpoint file |
97 | 97 | """ |
98 | 98 |
|
99 | | - class MNISTWithLoss(onnxblock.TrainingModel): |
100 | | - def __init__(self): |
101 | | - super().__init__() |
102 | | - self.loss = onnxblock.loss.CrossEntropyLoss() |
103 | | - |
104 | | - def build(self, output_name): |
105 | | - return self.loss(output_name) |
106 | | - |
107 | | - mnist_with_loss = MNISTWithLoss() |
108 | | - onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None |
109 | | - |
110 | | - # Build the training and eval graphs |
111 | | - logging.info("Using onnxblock to create the training artifacts.") |
112 | | - with onnxblock.onnx_model(onnx_model) as model_accessor: |
113 | | - _ = mnist_with_loss(onnx_model.graph.output[0].name) |
114 | | - eval_model = model_accessor.eval_model |
115 | | - |
116 | | - # Build the optimizer graph |
117 | | - optimizer = onnxblock.optim.AdamW() |
118 | | - with onnxblock.onnx_model() as accessor: |
119 | | - _ = optimizer(mnist_with_loss.parameters()) |
120 | | - optimizer_model = accessor.model |
| 99 | + onnx_model = onnx.load(model_path) |
| 100 | + |
| 101 | + requires_grad = [ |
| 102 | + param.name |
| 103 | + for param in onnx_model.graph.initializer |
| 104 | + if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point")) |
| 105 | + ] |
| 106 | + artifacts.generate_artifacts( |
| 107 | + onnx_model, |
| 108 | + requires_grad=requires_grad, |
| 109 | + loss=artifacts.LossType.CrossEntropyLoss, |
| 110 | + optimizer=artifacts.OptimType.AdamW, |
| 111 | + artifact_directory=artifacts_dir, |
| 112 | + prefix=model_prefix, |
| 113 | + ) |
121 | 114 |
|
122 | 115 | # Paths to the training artifacts generated above by generate_artifacts |
123 | | - train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx") |
124 | | - logging.info(f"Saving the training model to {train_model_path}.") |
125 | | - onnx.save(onnx_model, train_model_path) |
126 | | - eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx") |
127 | | - logging.info(f"Saving the eval model to {eval_model_path}.") |
128 | | - onnx.save(eval_model, eval_model_path) |
129 | | - optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx") |
130 | | - logging.info(f"Saving the optimizer model to {optimizer_model_path}.") |
131 | | - onnx.save(optimizer_model, optimizer_model_path) |
132 | | - trainable_params, non_trainable_params = mnist_with_loss.parameters() |
133 | | - checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt") |
134 | | - logging.info(f"Saving the checkpoint to {checkpoint_path}.") |
135 | | - onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path) |
| 116 | + train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx") |
| 117 | + eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx") |
| 118 | + optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx") |
| 119 | + checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint") |
136 | 120 |
|
137 | 121 | return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path |
0 commit comments