diff --git a/IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs b/IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs
index f3ae0b52..3bbc84fc 100644
--- a/IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs
+++ b/IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs
@@ -4,6 +4,7 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -26,24 +27,25 @@ namespace IPA.Utilities.Async
/// a factory for creating tasks on the default scheduler
public static TaskFactory Factory { get; } = new TaskFactory(Default);
- private readonly ConcurrentDictionary tasks = new ConcurrentDictionary();
- private int queueEndPosition = 0;
- private int queuePosition = 0;
+ private readonly ConcurrentQueue tasks = new ConcurrentQueue();
+ private static readonly ConditionalWeakTable itemTable = new ConditionalWeakTable();
- private struct QueueItem : IEquatable, IEquatable, IEquatable
+ private class QueueItem : IEquatable, IEquatable
{
- public int Index;
- public Task Task;
+ public bool HasTask;
+ private readonly WeakReference weakTask = null;
+ public Task Task => weakTask.TryGetTarget(out var task) ? task : null;
- public QueueItem(int index, Task task) : this()
+ public QueueItem(Task task)
{
- Index = index;
- Task = task;
+ HasTask = true;
+ weakTask = new WeakReference(task);
}
- public bool Equals(int other) => Index.Equals(other);
- public bool Equals(Task other) => Task.Equals(other);
- public bool Equals(QueueItem other) => other.Index == Index || other.Task == Task;
+ private bool Equals(WeakReference task)
+ => weakTask.TryGetTarget(out var t1) && task.TryGetTarget(out var t2) && t1.Equals(t2);
+ public bool Equals(Task other) => HasTask && weakTask.TryGetTarget(out var task) && other.Equals(task);
+ public bool Equals(QueueItem other) => other.HasTask == HasTask && Equals(other.weakTask);
}
///
@@ -126,18 +128,20 @@ namespace IPA.Utilities.Async
{
while (!Cancelling)
{
- if (queuePosition < queueEndPosition)
+ if (!tasks.IsEmpty)
{
var yieldAfter = YieldAfterTasks;
sw.Start();
- for (int i = 0; i < yieldAfter && queuePosition < queueEndPosition
+ for (int i = 0; i < yieldAfter && !tasks.IsEmpty
&& sw.Elapsed < YieldAfterTime; i++)
{
- if (tasks.TryRemove(new QueueItem { Index = Interlocked.Increment(ref queuePosition) }, out var task))
- TryExecuteTask(task); // we succesfully removed the task
- else
- i++; // we didn't
+ QueueItem task;
+ do if (!tasks.TryDequeue(out task)) goto exit; // try dequeue, if we can't exit
+ while (!task.HasTask); // if the dequeued task is empty, try again
+
+ TryExecuteTask(task.Task);
}
+ exit:
sw.Reset();
}
yield return null;
@@ -185,7 +189,9 @@ namespace IPA.Utilities.Async
{
ThrowIfDisposed();
- tasks.TryAdd(new QueueItem(Interlocked.Increment(ref queueEndPosition), task), task);
+ var item = new QueueItem(task);
+ itemTable.Add(task, item);
+ tasks.Enqueue(item);
}
///
@@ -206,8 +212,14 @@ namespace IPA.Utilities.Async
if (!UnityGame.OnMainThread) return false;
if (taskWasPreviouslyQueued)
- if (!tasks.TryRemove(new QueueItem { Task = task }, out var _))
- return false; // if we couldn't remove it, its not in our queue, so it already ran
+ {
+ if (itemTable.TryGetValue(task, out var item))
+ {
+ if (!item.HasTask) return false;
+ item.HasTask = false;
+ }
+ else return false; // if we couldn't remove it, its not in our queue, so it already ran
+ }
return TryExecuteTask(task);
}