Skip to content

Commit d6bf440

Browse files
committedJan 17, 2025
rework default scheduler settings
1 parent cc6b4ea commit d6bf440

File tree

4 files changed

+31
-24
lines changed

4 files changed

+31
-24
lines changed
 

‎Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
1010
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1213
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1314
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1415
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
@@ -35,6 +36,7 @@ LinearAlgebra = "1"
3536
OhMyThreads = "0.7.0"
3637
PackageExtensionCompat = "1"
3738
Random = "1"
39+
ScopedValues = "1.3.0"
3840
SparseArrays = "1"
3941
Strided = "2"
4042
TensorKitSectors = "0.1"

‎src/TensorKit.jl

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const TO = TensorOperations
102102

103103
using LRUCache
104104
using OhMyThreads
105+
using ScopedValues
105106

106107
using TensorKitSectors
107108
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ,
@@ -185,6 +186,7 @@ include("spaces/vectorspaces.jl")
185186
#-------------------------------------
186187
# general definitions
187188
include("tensors/abstracttensor.jl")
189+
include("tensors/backends.jl")
188190
include("tensors/blockiterator.jl")
189191
include("tensors/tensor.jl")
190192
include("tensors/adjoint.jl")

‎src/tensors/backends.jl

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Scheduler implementation
2+
# ------------------------
3+
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
4+
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
5+
6+
# Backend implementation
7+
# ----------------------
8+
# TODO: figure out a name
9+
# TODO: what should be the default scheduler?
10+
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
11+
arraybackend::B = TO.DefaultBackend()
12+
blockscheduler::BS = blockscheduler[]
13+
subblockscheduler::SBS = subblockscheduler[]
14+
end
15+
16+
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
17+
A::AbstractTensorMap)
18+
return TensorKitBackend()
19+
end
20+
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
21+
A::AbstractTensorMap)
22+
return TensorKitBackend()
23+
end
24+
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
25+
A::AbstractTensorMap, B::AbstractTensorMap)
26+
return TensorKitBackend()
27+
end

‎src/tensors/tensoroperations.jl

-24
Original file line numberDiff line numberDiff line change
@@ -144,30 +144,6 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
144144
# IMPLEMENTATONS
145145
#----------------
146146

147-
# Backend implementation
148-
# ----------------------
149-
# TODO: figure out a name
150-
# TODO: what should be the default scheduler?
151-
# TODO: should we allow a separate scheduler for "blocks" and "subblocks"
152-
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
153-
arraybackend::B = TO.DefaultBackend()
154-
blockscheduler::BS = SerialScheduler()
155-
subblockscheduler::SBS = SerialScheduler()
156-
end
157-
158-
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
159-
A::AbstractTensorMap)
160-
return TensorKitBackend()
161-
end
162-
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
163-
A::AbstractTensorMap)
164-
return TensorKitBackend()
165-
end
166-
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
167-
A::AbstractTensorMap, B::AbstractTensorMap)
168-
return TensorKitBackend()
169-
end
170-
171147
# Trace implementation
172148
#----------------------
173149
"""

0 commit comments

Comments
 (0)