@@ -388,30 +388,6 @@ def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]:
388
388
return new_str_list
389
389
390
390
391
- def make_pta_acc_builder_format (pta_str_list : List [str ]) -> List [str ]:
392
- new_str_list = []
393
- for pta_str in pta_str_list :
394
- if "packed_accessor" in pta_str :
395
- match = re .search (
396
- r"([a-zA-z0-9_]*)[.]packed_accessor([3|6][2|4])<(.*)>\(\)" , pta_str
397
- )
398
- assert match is not None and len (match .groups ()) == 3
399
- tensor , acc_nbits , args = match .groups ()
400
- if "acc_type" in args :
401
- match = re .search ("at::acc_type<([a-zA-Z_0-9]*), true>" , args )
402
- assert match is not None and len (match .groups ()) == 1
403
- new_type = match .group (1 )
404
- args = re .sub ("at::acc_type<[a-zA-Z_]*, true>" , new_type , args )
405
- macro_name = "PTA_ACC_B"
406
- else :
407
- macro_name = "PTA_B"
408
- args = args .replace (", at::RestrictPtrTraits" , "" )
409
- new_str_list .append (f"{ macro_name } ({ tensor } , { args } , { acc_nbits } )" )
410
- else :
411
- new_str_list .append (pta_str )
412
- return new_str_list
413
-
414
-
415
391
def replace_pta_namespace (pta_str_list : List [str ]) -> List [str ]:
416
392
return [
417
393
pta_str .replace ("at::PackedTensorAccessor" , "pta::PackedTensorAccessor" )
@@ -455,7 +431,6 @@ def to_upper_placeholder_types(arg_str_list: List[str]) -> List[str]:
455
431
################################################################################
456
432
457
433
env .filters ["make_pta_acc_format" ] = make_pta_acc_format
458
- env .filters ["make_pta_acc_builder_format" ] = make_pta_acc_builder_format
459
434
env .filters ["replace_pta_namespace" ] = replace_pta_namespace
460
435
env .filters ["replace_placeholder_types" ] = replace_placeholder_types
461
436
env .filters ["to_upper_placeholder_types" ] = to_upper_placeholder_types
0 commit comments