diff --git a/src/finch/julia.py b/src/finch/julia.py index f1fb19d..f391175 100644 --- a/src/finch/julia.py +++ b/src/finch/julia.py @@ -17,7 +17,7 @@ ): juliapkg.add(_FINCH_NAME, _FINCH_HASH, version=_FINCH_VERSION) -import juliacall # noqa +import juliacall as jc # noqa juliapkg.resolve() from juliacall import Main as jl # noqa diff --git a/src/finch/tensor.py b/src/finch/tensor.py index e9deb67..bc4b480 100644 --- a/src/finch/tensor.py +++ b/src/finch/tensor.py @@ -7,7 +7,7 @@ from . import dtypes as jl_dtypes from .errors import PerformanceWarning -from .julia import jl +from .julia import jc, jl from .levels import ( _Display, Dense, @@ -338,7 +338,7 @@ def todense(self) -> np.ndarray: else: # create materialized dense array shape = jl.size(obj) - dense_lvls = jl.Element(jl.default(obj)) + dense_lvls = jl.Element(jc.convert(self.dtype, jl.default(obj))) for _ in range(self.ndim): dense_lvls = jl.Dense(dense_lvls) dense_tensor = jl.Tensor(dense_lvls, obj).lvl # materialize @@ -748,7 +748,7 @@ def astype(x: Tensor, dtype: DType, /, *, copy: bool = True): else: finch_tns = x._obj.body result = jl.copyto_b( - jl.similar(finch_tns, jl.default(finch_tns), dtype), finch_tns + jl.similar(finch_tns, jc.convert(dtype, jl.default(finch_tns)), dtype), finch_tns ) return Tensor(jl.swizzle(result, *x.get_order(zero_indexing=False))) @@ -785,16 +785,7 @@ def _reduce(x: Tensor, fn: Callable, axis, dtype=None): result = fn(x._obj, dims=axis) else: result = fn(x._obj) - - if ( - jl.isa(result, jl.Finch.SwizzleArray) or - jl.isa(result, jl.Finch.Tensor) or - jl.isa(result, jl.Finch.LazyTensor) - ): - result = Tensor(result) - else: - result = np.array(result) - return result + return Tensor(result) def sum( diff --git a/tests/test_ops.py b/tests/test_ops.py index 2649546..6f8ee44 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -167,10 +167,7 @@ def test_reductions(arr3d, func_name, axis, dtype): actual = getattr(finch, func_name)(A_finch, axis=axis) expected = getattr(np, func_name)(arr3d, axis=axis) - if isinstance(actual, finch.Tensor): - actual = actual.todense() - - assert_equal(actual, expected) + assert_equal(actual.todense(), expected) @pytest.mark.parametrize( diff --git a/tests/test_sparse.py b/tests/test_sparse.py index 25f4a86..e5ab04d 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -142,7 +142,7 @@ def test_permute_dims(arr3d, permutation, order): @pytest.mark.parametrize("order", ["C", "F"]) def test_astype(arr3d, order): - arr = np.array(arr3d, order=order) + arr = np.array(arr3d, order=order, dtype=np.int64) storage = finch.Storage( finch.Dense(finch.SparseList(finch.SparseList(finch.Element(np.int64(0))))), order=order, @@ -150,15 +150,21 @@ def test_astype(arr3d, order): arr_finch = finch.Tensor(arr).to_device(storage) result = finch.astype(arr_finch, finch.int64) - assert_equal(result.todense(), arr) - assert not arr_finch is result + assert not result is arr_finch + result = result.todense() + assert_equal(result, arr) + assert result.dtype == arr.dtype result = finch.astype(arr_finch, finch.int64, copy=False) - assert_equal(result.todense(), arr) - assert arr_finch is result - - result = finch.astype(arr_finch, finch.float32) - assert_equal(result.todense(), arr.astype(np.float32)) + assert result is arr_finch + result = result.todense() + assert_equal(result, arr) + assert result.dtype == arr.dtype + + result = finch.astype(arr_finch, finch.float32).todense() + arr = arr.astype(np.float32) + assert_equal(result, arr) + assert result.dtype == arr.dtype with pytest.raises( ValueError, match="Unable to avoid a copy while casting in no-copy mode."