You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I would like to use torchao to get an int8 model, and use FSDP to save memory.
The following code is a tiny toy example to test this goal:
`import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, float8_dynamic_activation_float8_weight, int8_weight_only
import copy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
Can FSDP work with torchao for inference?
I would like to use torchao to get an int8 model, and use FSDP to save memory.
The following code is a tiny toy example to test this goal:
`import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, float8_dynamic_activation_float8_weight, int8_weight_only
import copy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # Must be the real constructor name: nn.Module.__init__ has to run
        # before any submodule is assigned as an attribute.
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply linear1, then ReLU, then linear2 to ``x``."""
        return self.linear2(self.relu(self.linear1(x)))
weight_path = "xxx/ffn_weights.pth"  # placeholder path to a saved FFN state_dict
input_dim = 10
hidden_dim = 20
output_dim = 10


def _report_linear_weights(model, stage):
    """Print size and stride of every nn.Linear weight in ``model``, prefixed by ``stage``."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            print(f"{stage}: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")


def _report_requires_grad(model):
    """Print the ``requires_grad`` flag of every parameter in ``model``."""
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, requires_grad={param.requires_grad}")


def main():
    """Load FFN weights, int8-quantize them with torchao, then shard the model with FSDP2.

    Intended to be launched with one process per GPU (e.g. via torchrun), using
    the NCCL backend.
    """
    import os

    from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

    # Bind this process to its own GPU. Without set_device, current_device()
    # is device 0 in every process, so all ranks would share one GPU.
    # NOTE(review): assumes the launcher sets LOCAL_RANK (torchrun does).
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend="nccl")

    try:
        base_model = FFN(input_dim, hidden_dim, output_dim).to(device)
        # map_location puts the tensors on this rank's GPU.
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(weight_path, map_location=device, weights_only=True)
        base_model.load_state_dict(state_dict)
        base_model.eval()
        print("model structure", base_model)

        _report_linear_weights(base_model, "Before quantization")
        # Quantize exactly once: a second quantize_ pass over weights that are
        # already quantized is redundant and obscures which config is in effect.
        quantize_(base_model, Int8DynamicActivationInt8WeightConfig())
        print("q_model", base_model)
        _report_linear_weights(base_model, "after quantization")

        # Inference only: freeze every parameter.
        _report_requires_grad(base_model)
        for param in base_model.parameters():
            param.requires_grad = False
        _report_requires_grad(base_model)

        # Use a single sharding API. Wrapping the same module with
        # FullyShardedDataParallel (FSDP1) and also calling fully_shard (FSDP2)
        # on it mixes two different wrappers.
        # Per-Linear sharding followed by the root mirrors ModuleWrapPolicy({nn.Linear}).
        # The modules are collected first so the tree is not walked while being modified.
        linears = [m for m in base_model.modules() if isinstance(m, nn.Linear)]
        # NOTE(review): fully_shard splits each parameter along dim 0 with
        # torch.chunk. torchao's LinearActivationQuantizedTensor raised
        # NotImplementedError for aten.split in the version this was tested
        # with, so this step can fail depending on the installed torchao
        # version. Confirm support before relying on it.
        for linear in linears:
            fully_shard(linear)
        fully_shard(base_model)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
`
Then I get the following error:
[rank1]: Traceback (most recent call last):
[rank1]: File "fsdp_test.py", line 77, in
[rank1]: fully_shard(base_model)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]: updated = func(inp_module, *args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 129, in fully_shard
[rank1]: state._fsdp_param_group = FSDPParamGroup(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in init
[rank1]: self.fsdp_params = [
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in
[rank1]: FSDPParam(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 226, in init
[rank1]: self._init_sharded_param(param, device)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 310, in _init_sharded_param
[rank1]: chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fsdp_common.py", line 94, in chunk_with_empty
[rank1]: chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]: File "/ao/torchao/utils.py", line 425, in dispatch__torch_function
[rank1]: return func(*args, **kwargs)
[rank1]: File "/ao/torchao/utils.py", line 444, in dispatch__torch_dispatch
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}
The text was updated successfully, but these errors were encountered: