
Commit 124bde9

Bring QAT POC back to a functional state (#19290)
1 parent 6226c5f commit 124bde9

File tree

4 files changed (+27, -51 lines)


orttraining/orttraining/test/python/qat_poc_example/README.md

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t
 
 > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC.
 
-> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True`
+> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True`
 
 > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag QuantizeBias=False`
 
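For context, the two flags called out in the README map onto `onnxruntime.quantization`'s static quantization entry point. Below is a minimal sketch of how such a QDQ model might be produced; the model paths, the input tensor name, and the random calibration reader are placeholders for illustration, not part of this commit:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class RandomCalibrationReader(CalibrationDataReader):
    """Stand-in calibration reader; the real POC calibrates on MNIST batches."""

    def __init__(self, input_name: str = "input", num_batches: int = 8):
        # Each item maps the model's input name (assumed here) to a sample batch.
        self._data = iter(
            [{input_name: np.random.rand(1, 784).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._data, None)


quantize_static(
    model_input="mnist_fp32.onnx",   # placeholder path to the float model
    model_output="mnist_qdq.onnx",   # placeholder path for the QDQ output
    calibration_data_reader=RandomCalibrationReader(),
    quant_format=QuantFormat.QDQ,
    extra_options={
        "AddQDQPairToWeight": True,  # keep a full QDQ pair on weights so QAT can train them
        "QuantizeBias": False,       # leave the bias term unquantized, as the note explains
    },
)
```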

orttraining/orttraining/test/python/qat_poc_example/model.py

Lines changed: 20 additions & 36 deletions
@@ -5,7 +5,7 @@
 import onnx
 import torch
 
-import onnxruntime.training.onnxblock as onnxblock
+from onnxruntime.training import artifacts
 
 
 class MNIST(torch.nn.Module):
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix):
     4. The checkpoint file
     """
 
-    class MNISTWithLoss(onnxblock.TrainingModel):
-        def __init__(self):
-            super().__init__()
-            self.loss = onnxblock.loss.CrossEntropyLoss()
-
-        def build(self, output_name):
-            return self.loss(output_name)
-
-    mnist_with_loss = MNISTWithLoss()
-    onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None
-
-    # Build the training and eval graphs
-    logging.info("Using onnxblock to create the training artifacts.")
-    with onnxblock.onnx_model(onnx_model) as model_accessor:
-        _ = mnist_with_loss(onnx_model.graph.output[0].name)
-        eval_model = model_accessor.eval_model
-
-    # Build the optimizer graph
-    optimizer = onnxblock.optim.AdamW()
-    with onnxblock.onnx_model() as accessor:
-        _ = optimizer(mnist_with_loss.parameters())
-        optimizer_model = accessor.model
+    onnx_model = onnx.load(model_path)
+
+    requires_grad = [
+        param.name
+        for param in onnx_model.graph.initializer
+        if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point"))
+    ]
+    artifacts.generate_artifacts(
+        onnx_model,
+        requires_grad=requires_grad,
+        loss=artifacts.LossType.CrossEntropyLoss,
+        optimizer=artifacts.OptimType.AdamW,
+        artifact_directory=artifacts_dir,
+        prefix=model_prefix,
+    )
 
     # Create the training artifacts
-    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
-    logging.info(f"Saving the training model to {train_model_path}.")
-    onnx.save(onnx_model, train_model_path)
-    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
-    logging.info(f"Saving the eval model to {eval_model_path}.")
-    onnx.save(eval_model, eval_model_path)
-    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
-    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
-    onnx.save(optimizer_model, optimizer_model_path)
-    trainable_params, non_trainable_params = mnist_with_loss.parameters()
-    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
-    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
-    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)
+    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx")
+    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx")
+    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx")
+    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint")
 
     return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
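For orientation, here is a small usage sketch of the refactored helper above. The local import and the input model name are assumptions for illustration; the file names in the comments simply follow the prefix-based paths the function now returns:

```python
import os

# Hypothetical invocation of the helper defined in model.py above;
# the import path assumes the POC example's own model.py module.
from model import create_training_artifacts

paths = create_training_artifacts(
    model_path="mnist_qdq.onnx",      # placeholder: the statically quantized QDQ model
    artifacts_dir="training_artifacts",
    model_prefix="mnist_qat_",        # note the trailing underscore used by qat.py
)

# With that prefix, generate_artifacts writes its outputs into artifacts_dir, e.g.:
#   training_artifacts/mnist_qat_training_model.onnx
#   training_artifacts/mnist_qat_eval_model.onnx
#   training_artifacts/mnist_qat_optimizer_model.onnx
#   training_artifacts/mnist_qat_checkpoint
for path in paths:
    print(path, os.path.exists(path))
```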

orttraining/orttraining/test/python/qat_poc_example/qat.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@
 )
 
 logging.info("Preparing the training artifacts for QAT.")
-training_model_name = "mnist_qat"
+training_model_name = "mnist_qat_"
 artifacts_dir = os.path.join(model_dir, "training_artifacts")
 utils.makedir(artifacts_dir)
 training_artifacts = create_training_artifacts(

orttraining/orttraining/test/python/qat_poc_example/train.py

Lines changed: 5 additions & 13 deletions
@@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader):
     model.train()
     cumulative_loss = 0
     for data, target in train_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        train_loss = model(forward_inputs)
+        train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
         optimizer.step()
         model.lazy_reset_grad()
-        cumulative_loss += train_loss[0]
+        cumulative_loss += train_loss
 
     return cumulative_loss / len(train_loader)
 
@@ -43,12 +39,8 @@ def _eval(model, test_loader):
     model.eval()
     cumulative_loss = 0
     for data, target in test_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        test_loss = model(forward_inputs)
-        cumulative_loss += test_loss[0]
+        test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
+        cumulative_loss += test_loss
 
     return cumulative_loss / len(test_loader)
 
@@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp
     train_loader, test_loader = _get_dataloaders("data", batch_size)
 
     # Load the checkpoint state.
-    state = orttraining.CheckpointState(qat_checkpoint)
+    state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint)
 
     # Create the training module.
     model = orttraining.Module(qat_train_model, state, qat_eval_model)
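Taken together, the train.py changes track the `onnxruntime.training.api` calling convention: the checkpoint is loaded via the `load_checkpoint` class method, the `Module` is called with positional numpy inputs (labels as int64), and it returns the loss directly rather than a list. A minimal sketch under those assumptions, with the import alias and artifact file names as placeholders:

```python
import numpy as np
from onnxruntime.training import api as orttraining  # assumed alias matching train.py

# Placeholder artifact paths; the real ones come from create_training_artifacts.
state = orttraining.CheckpointState.load_checkpoint("mnist_qat_checkpoint")
model = orttraining.Module("mnist_qat_training_model.onnx", state, "mnist_qat_eval_model.onnx")
optimizer = orttraining.Optimizer("mnist_qat_optimizer_model.onnx", model)

model.train()
data = np.random.rand(64, 784).astype(np.float32)            # a fake MNIST batch
target = np.random.randint(0, 10, size=64).astype(np.int64)   # int64 labels, as in the diff
loss = model(data, target)   # loss returned directly, not wrapped in a list
optimizer.step()
model.lazy_reset_grad()
print(float(loss))
```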

0 commit comments
