Browse Source

Fixed UnityMainThreadTaskScheduler to use ConditionalWeakTable

pull/46/head
Anairkoen Schno 4 years ago
parent
commit
d1d7de0f71
1 changed files with 33 additions and 21 deletions
  1. +33
    -21
      IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs

+ 33
- 21
IPA.Loader/Utilities/Async/UnityMainThreadTaskScheduler.cs View File

@ -4,6 +4,7 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@ -26,24 +27,25 @@ namespace IPA.Utilities.Async
/// <value>a factory for creating tasks on the default scheduler</value>
public static TaskFactory Factory { get; } = new TaskFactory(Default);
private readonly ConcurrentDictionary<QueueItem, Task> tasks = new ConcurrentDictionary<QueueItem, Task>();
private int queueEndPosition = 0;
private int queuePosition = 0;
private readonly ConcurrentQueue<QueueItem> tasks = new ConcurrentQueue<QueueItem>();
private static readonly ConditionalWeakTable<Task, QueueItem> itemTable = new ConditionalWeakTable<Task, QueueItem>();
private struct QueueItem : IEquatable<int>, IEquatable<Task>, IEquatable<QueueItem>
private class QueueItem : IEquatable<Task>, IEquatable<QueueItem>
{
public int Index;
public Task Task;
public bool HasTask;
private readonly WeakReference<Task> weakTask = null;
public Task Task => weakTask.TryGetTarget(out var task) ? task : null;
public QueueItem(int index, Task task) : this()
public QueueItem(Task task)
{
Index = index;
Task = task;
HasTask = true;
weakTask = new WeakReference<Task>(task);
}
public bool Equals(int other) => Index.Equals(other);
public bool Equals(Task other) => Task.Equals(other);
public bool Equals(QueueItem other) => other.Index == Index || other.Task == Task;
private bool Equals(WeakReference<Task> task)
=> weakTask.TryGetTarget(out var t1) && task.TryGetTarget(out var t2) && t1.Equals(t2);
public bool Equals(Task other) => HasTask && weakTask.TryGetTarget(out var task) && other.Equals(task);
public bool Equals(QueueItem other) => other.HasTask == HasTask && Equals(other.weakTask);
}
/// <summary>
@ -126,18 +128,20 @@ namespace IPA.Utilities.Async
{
while (!Cancelling)
{
if (queuePosition < queueEndPosition)
if (!tasks.IsEmpty)
{
var yieldAfter = YieldAfterTasks;
sw.Start();
for (int i = 0; i < yieldAfter && queuePosition < queueEndPosition
for (int i = 0; i < yieldAfter && !tasks.IsEmpty
&& sw.Elapsed < YieldAfterTime; i++)
{
if (tasks.TryRemove(new QueueItem { Index = Interlocked.Increment(ref queuePosition) }, out var task))
TryExecuteTask(task); // we succesfully removed the task
else
i++; // we didn't
QueueItem task;
do if (!tasks.TryDequeue(out task)) goto exit; // try dequeue, if we can't exit
while (!task.HasTask); // if the dequeued task is empty, try again
TryExecuteTask(task.Task);
}
exit:
sw.Reset();
}
yield return null;
@ -185,7 +189,9 @@ namespace IPA.Utilities.Async
{
ThrowIfDisposed();
tasks.TryAdd(new QueueItem(Interlocked.Increment(ref queueEndPosition), task), task);
var item = new QueueItem(task);
itemTable.Add(task, item);
tasks.Enqueue(item);
}
/// <summary>
@ -206,8 +212,14 @@ namespace IPA.Utilities.Async
if (!UnityGame.OnMainThread) return false;
if (taskWasPreviouslyQueued)
if (!tasks.TryRemove(new QueueItem { Task = task }, out var _))
return false; // if we couldn't remove it, its not in our queue, so it already ran
{
if (itemTable.TryGetValue(task, out var item))
{
if (!item.HasTask) return false;
item.HasTask = false;
}
else return false; // if we couldn't remove it, its not in our queue, so it already ran
}
return TryExecuteTask(task);
}


Loading…
Cancel
Save