int8 quantization with FSDP for inference error #2127

Open
Andy0422 opened this issue Apr 25, 2025 · 0 comments

Andy0422 commented Apr 25, 2025

Can FSDP work with torchao in inference?

I would like to use torchao to produce an int8 model and then wrap it with FSDP to save memory.

The following code is a tiny toy example to test this:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, float8_dynamic_activation_float8_weight, int8_weight_only
import copy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


class FFN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


weight_path = "xxx/ffn_weights.pth"

dist.init_process_group(backend='nccl')

input_dim = 10
hidden_dim = 20
output_dim = 10

base_model = FFN(input_dim, hidden_dim, output_dim).to(torch.cuda.current_device())
base_model.load_state_dict(torch.load(weight_path))
print("model structure", base_model)
fsdp_model = copy.deepcopy(base_model)

for name, module in base_model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"Before quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

quantize_(base_model, int8_dynamic_activation_int8_weight())

print("q_model", base_model)

from torchao.quantization.quant_api import (
    quantize_,
    Int8DynamicActivationInt8WeightConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)
quantize_(base_model, Int8DynamicActivationInt8WeightConfig())
print("q_model_new_api", base_model)

for name, module in base_model.named_modules():
    if isinstance(module, nn.Linear):
        # print(f"Before quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")
        # setattr(model, name, quantize_(module, int8_dynamic_activation_int8_weight()))
        print(f"after quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

for name, param in base_model.named_parameters():
    print(f"Parameter: {name}, requires_grad={param.requires_grad}")

for param in base_model.parameters():
    param.requires_grad = False

for name, param in base_model.named_parameters():
    print(f"Parameter: {name}, requires_grad={param.requires_grad}")

wrap_policy = ModuleWrapPolicy({nn.Linear})

model = FSDP(base_model, auto_wrap_policy=wrap_policy,
             use_orig_params=True)

fully_shard(base_model)
```

Then I get the following error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "fsdp_test.py", line 77, in <module>
[rank1]:     fully_shard(base_model)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]:     updated = func(inp_module, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 129, in fully_shard
[rank1]:     state._fsdp_param_group = FSDPParamGroup(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in __init__
[rank1]:     self.fsdp_params = [
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in <listcomp>
[rank1]:     FSDPParam(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 226, in __init__
[rank1]:     self._init_sharded_param(param, device)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 310, in _init_sharded_param
[rank1]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_common.py", line 94, in _chunk_with_empty
[rank1]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]:   File "/ao/torchao/utils.py", line 425, in _dispatch__torch_function__
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/ao/torchao/utils.py", line 444, in _dispatch__torch_dispatch__
[rank1]:     raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}
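For what it's worth, the failure does not seem to need FSDP at all: `fully_shard` shards each parameter by calling `torch.chunk(param, shard_world_size, dim=0)`, which dispatches `aten.split.Tensor` on the quantized weight, and `LinearActivationQuantizedTensor` apparently does not implement that op. Below is a minimal single-process sketch (my own repro attempt, not part of the script above; it assumes `quantize_` wraps a bare `nn.Linear`'s weight in the same tensor subclass as it does inside the FFN) that hits the same `NotImplementedError`:

```python
# Minimal sketch (not part of the original script): reproduce the dispatch
# failure without FSDP. Assumes quantize_ wraps a bare nn.Linear's weight in
# LinearActivationQuantizedTensor, just as it does for the FFN's linears above.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig

linear = nn.Linear(10, 20)
quantize_(linear, Int8DynamicActivationInt8WeightConfig())
print(type(linear.weight))  # inspect the wrapped weight type

try:
    # The same call that _chunk_with_empty makes inside fully_shard; it
    # dispatches aten.split.Tensor on the quantized weight subclass.
    torch.chunk(linear.weight, 2, dim=0)
except NotImplementedError as e:
    print("same error as under fully_shard:", e)
```

If that is the case, the question is whether there is a supported way to shard these torchao tensor subclasses with FSDP/`fully_shard` for inference, or whether `aten.split` support would need to be added for them.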
