@@ -1170,6 +1170,8 @@ def floor(x):
1170
1170
1171
1171
def gather (params , indices , axis = None ):
1172
1172
op = P .Gather ()
1173
+ if axis is None :
1174
+ axis = 0
1173
1175
return op (params , indices , axis )
1174
1176
1175
1177
@@ -1590,10 +1592,7 @@ def reduce_std(x, axis=None, keepdims=False):
1590
1592
1591
1593
1592
1594
def reduce_sum (x , axis = None , keepdims = False ):
1593
- op = P .ReduceSum (keep_dims = keepdims )
1594
- if axis is None :
1595
- return op (x )
1596
- return op (x , axis = axis )
1595
+ return msnp .sum (x , axis = axis , keepdims = keepdims )
1597
1596
1598
1597
1599
1598
def reduce_variance (x , axis = None , keepdims = False ):
@@ -1729,11 +1728,15 @@ def tanh(x):
1729
1728
1730
1729
def any (x , axis = None , keepdims = False ):
1731
1730
op = P .ReduceAny (keep_dims = keepdims )
1731
+ if axis is None :
1732
+ return op (x )
1732
1733
return op (x , axis )
1733
1734
1734
1735
1735
1736
def all (x , axis = None , keepdims = False ):
1736
1737
op = P .ReduceAll (keep_dims = keepdims )
1738
+ if axis is None :
1739
+ return op (x )
1737
1740
return op (x , axis )
1738
1741
1739
1742
@@ -1779,8 +1782,7 @@ def zeros_like(x, dtype=None):
1779
1782
1780
1783
1781
1784
def squeeze (x , axis = None ):
1782
- op = P .Squeeze (axis )
1783
- return op (x )
1785
+ return msnp .squeeze (x , axis )
1784
1786
1785
1787
1786
1788
def unsorted_segment_sum (x , segment_ids , num_segments ):
@@ -1792,7 +1794,7 @@ def unsorted_segment_sum(x, segment_ids, num_segments):
1792
1794
def unsorted_segment_mean (x , segment_ids , num_segments ):
1793
1795
segment_ids = ms .Tensor (segment_ids )
1794
1796
op = P .UnsortedSegmentSum ()
1795
- x_one = msnp .ones_like (x , dtype = x .dtype )
1797
+ x_one = msnp .ones_like (x , dtype = x .dtype )
1796
1798
sum = op (x , segment_ids , num_segments )
1797
1799
one = op (x_one , segment_ids , num_segments )
1798
1800
return sum / one
0 commit comments