Skip to content

Commit

Permalink
Switch most interop code in Vezel.Ruptura.Injection to Vezel.Ruptura.…
Browse files Browse the repository at this point in the history
…System APIs.

Part of #30.
  • Loading branch information
alexrp committed Jul 15, 2022
1 parent 81c0096 commit 180dfb8
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 175 deletions.
107 changes: 54 additions & 53 deletions src/injection/AssemblyInjector.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using Vezel.Ruptura.Injection.IO;
using Vezel.Ruptura.Injection.Threading;
using Windows.Win32.Foundation;
using Windows.Win32.System.Memory;
using static Iced.Intel.AssemblerRegisters;
using Win32 = Windows.Win32.WindowsPInvoke;

namespace Vezel.Ruptura.Injection;

Expand Down Expand Up @@ -31,13 +28,15 @@ unsafe struct RupturaParameters

bool _injecting;

bool _waiting;

nuint _loadLibraryW;

nuint _getProcAddress;

nuint _getLastError;

SafeHandle? _threadHandle;
ThreadObject? _thread;

public AssemblyInjector(TargetProcess process, AssemblyInjectorOptions options)
{
Expand All @@ -63,7 +62,7 @@ public void Dispose()

void DisposeCore()
{
_threadHandle?.Dispose();
_thread?.Dispose();
}

string GetModulePath()
Expand All @@ -74,9 +73,9 @@ string GetModulePath()
return File.Exists(path) ? path : throw new InjectionException("Could not locate the Ruptura native module.");
}

void PopulateMemoryArea(nuint area, nuint length, Action<ProcessMemoryStream, InjectionBinaryWriter> action)
void PopulateMemoryArea(nuint area, nint length, Action<ProcessMemoryStream, InjectionBinaryWriter> action)
{
using var stream = new ProcessMemoryStream(_process, area, length);
using var stream = new ProcessMemoryStream(_process.Object, area, length);
using var writer = new InjectionBinaryWriter(stream, true);

action(stream, writer);
Expand All @@ -102,23 +101,20 @@ void ForceLoaderInitialization()
{
// Spawning a live thread in a process that was created suspended forces the Windows image loader to finish
// loading the image so that, among other things, we will be able to resolve kernel32.dll exports.
using var threadHandle = _process.CreateThread(initializeShell, 0);
using var thread = _process.CreateThread(initializeShell, 0);

switch ((WIN32_ERROR)Win32.WaitForSingleObjectEx(
threadHandle, (uint)(long)_options.InjectionTimeout.TotalMilliseconds, false))
// TODO: Consider making this async with ThreadPool.UnsafeRegisterWaitForSingleObject().
switch (thread.Wait(_options.InjectionTimeout, false))
{
case WIN32_ERROR.WAIT_OBJECT_0:
case WaitResult.Signaled:
break;
case WIN32_ERROR.WAIT_TIMEOUT:
case WaitResult.TimedOut:
throw new TimeoutException();
default:
throw new Win32Exception();
throw new UnreachableException();
}

if (!Win32.GetExitCodeThread(threadHandle, out var code))
throw new Win32Exception();

if (code != 0)
if (thread.GetExitCode() is var code and not 0)
throw new InjectionException($"Failed to initialize the target process: 0x{code:x}");
}
finally
Expand All @@ -132,7 +128,7 @@ void RetrieveKernel32Exports()
if (_process.GetModule("kernel32.dll") is not (var k32Addr, var k32Size))
throw new InjectionException("Could not locate 'kernel32.dll' in the target process.");

using var stream = new ProcessMemoryStream(_process, k32Addr, k32Size);
using var stream = new ProcessMemoryStream(_process.Object, k32Addr, k32Size);

var exports = new PeFile(stream).ExportedFunctions;

Expand All @@ -148,7 +144,7 @@ nuint GetExport(string name)
_getLastError = GetExport("GetLastError");
}

unsafe (nuint Address, nuint Length) CreateParameterArea()
unsafe (nuint Address, nint Length) CreateParameterArea()
{
// Keep in sync with src/module/main.h.

Expand All @@ -159,12 +155,10 @@ nuint GetExport(string name)
foreach (var arg in _options.Arguments.Prepend(_options.FileName))
length += Encoding.Unicode.GetByteCount(arg) + sizeof(char);

var len = (nuint)length;

return (_process.AllocMemory((nuint)length, PAGE_PROTECTION_FLAGS.PAGE_READWRITE), len);
return (_process.AllocateMemory(length, MemoryAccess.ReadWrite), length);
}

unsafe void PopulateParameterArea(nuint address, nuint length)
unsafe void PopulateParameterArea(nuint address, nint length)
{
// Keep in sync with src/module/main.h.

Expand Down Expand Up @@ -213,14 +207,14 @@ async Task InjectModuleAsync(string modulePath, nuint parameterArea, MemoryMappe
{
var nameAreaLength = Encoding.Unicode.GetByteCount(modulePath) + sizeof(char) +
Encoding.ASCII.GetByteCount(NativeEntryPoint) + sizeof(byte);
var nameArea = _process.AllocMemory((uint)nameAreaLength, PAGE_PROTECTION_FLAGS.PAGE_READWRITE);
var nameArea = _process.AllocateMemory(nameAreaLength, MemoryAccess.ReadWrite);

try
{
nuint modulePathPtr = 0;
nuint entryPointPtr = 0;

PopulateMemoryArea(nameArea, (uint)nameAreaLength, (stream, writer) =>
PopulateMemoryArea(nameArea, nameAreaLength, (stream, writer) =>
{
modulePathPtr = nameArea + (nuint)stream.Position;
Expand Down Expand Up @@ -300,7 +294,7 @@ async Task InjectModuleAsync(string modulePath, nuint parameterArea, MemoryMappe

try
{
var threadHandle = _process.CreateThread(injectShell, parameterArea);
var thread = _process.CreateThread(injectShell, parameterArea);

try
{
Expand All @@ -312,24 +306,22 @@ async Task InjectModuleAsync(string modulePath, nuint parameterArea, MemoryMappe
// Did injection complete successfully?
if (accessor.ReadBoolean(0))
{
_threadHandle = threadHandle;
_thread = thread;

break;
}

// Did the thread exit with an error?
switch ((WIN32_ERROR)Win32.WaitForSingleObjectEx(threadHandle, 0, false))
switch (thread.Wait(TimeSpan.Zero, false))
{
case WIN32_ERROR.WAIT_OBJECT_0:
if (!Win32.GetExitCodeThread(threadHandle, out var code))
throw new Win32Exception();

case WaitResult.Signaled:
throw new InjectionException(
$"Failed to inject the native module into the target process: 0x{code:x}");
case WIN32_ERROR.WAIT_TIMEOUT:
$"Failed to inject the native module into the target process: " +
$"0x{thread.GetExitCode():x}");
case WaitResult.TimedOut:
break;
default:
throw new Win32Exception();
throw new UnreachableException();
}

await Task.Delay(100);
Expand All @@ -340,12 +332,12 @@ async Task InjectModuleAsync(string modulePath, nuint parameterArea, MemoryMappe
}
catch (Exception)
{
threadHandle.Dispose();
thread.Dispose();

throw;
}

_threadHandle = threadHandle;
_thread = thread;
}
finally
{
Expand Down Expand Up @@ -399,32 +391,41 @@ public Task InjectAssemblyAsync()

public Task<int> WaitForCompletionAsync()
{
_ = _threadHandle is not null and { IsInvalid: false } ? true : throw new InvalidOperationException();
_ = _thread != null && !_waiting ? true : throw new InvalidOperationException();

_waiting = true;

return Task.Run(async () =>
{
using var waitHandle = new ThreadWaitHandle(new(_threadHandle.DangerousGetHandle(), true));
// Transfer ownership of the native handle from _threadHandle to waitHandle so that it stays alive until the
// injected assembly returns, allowing us to retrieve the exit code.
_threadHandle.SetHandleAsInvalid();
// This is safe because the lambda below captures the thread object and keeps it alive.
using var waitHandle = new ThreadWaitHandle(new(_thread.SafeHandle.DangerousGetHandle(), false));
var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var registration = ThreadPool.UnsafeRegisterWaitForSingleObject(
waitHandle,
(_, timeout) =>
{
var ex = default(Exception);
if (timeout)
ex = new TimeoutException();
else if (!Win32.GetExitCodeThread(waitHandle.SafeWaitHandle, out var code))
ex = new Win32Exception();
else
tcs.SetResult((int)code);
if (ex != null)
tcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex));
{
tcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new TimeoutException()));
return;
}
int code;
try
{
code = _thread.GetExitCode();
}
catch (Win32Exception ex)
{
tcs.SetException(ex);
return;
}
tcs.SetResult(code);
},
null,
_options.CompletionTimeout,
Expand Down
42 changes: 22 additions & 20 deletions src/injection/IO/ProcessMemoryStream.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace Vezel.Ruptura.Injection.IO;

sealed class ProcessMemoryStream : Stream
sealed unsafe class ProcessMemoryStream : Stream
{
// TODO: Review some of the casts here.

Expand All @@ -10,39 +10,39 @@ sealed class ProcessMemoryStream : Stream

public override bool CanSeek => true;

public override long Length => (nint)_length;
public override long Length => _length;

public override long Position
{
get => (nint)_position;
get => _position;
set
{
_ = value >= 0 ? true : throw new ArgumentOutOfRangeException(nameof(value));

_position = (nuint)value;
_position = (nint)value;
}
}

readonly TargetProcess _process;
readonly ProcessObject _process;

readonly nuint _address;
readonly void* _address;

readonly nuint _length;
readonly nint _length;

nuint _position;
nint _position;

bool _wrote;

public ProcessMemoryStream(TargetProcess process, nuint address, nuint length)
public ProcessMemoryStream(ProcessObject process, nuint address, nint length)
{
_process = process;
_address = address;
_address = (void*)address;
_length = length;
}

public override long Seek(long offset, SeekOrigin origin)
{
var off = (nuint)offset;
var off = (nint)offset;

switch (origin)
{
Expand All @@ -68,13 +68,13 @@ public override long Seek(long offset, SeekOrigin origin)
throw new ArgumentOutOfRangeException(nameof(origin));
}

return (nint)_position;
return _position;
}

public override void Flush()
{
if (_wrote)
_process.FlushCache(_address, _length);
_process.FlushInstructionCache(_address, _length);
}

public override void SetLength(long value)
Expand Down Expand Up @@ -102,21 +102,22 @@ public override Task<int> ReadAsync(

public override int Read(Span<byte> buffer)
{
var len = (int)nuint.Min(_length - _position, (uint)buffer.Length);
var len = (int)nint.Min(_length - _position, buffer.Length);

if (len <= 0)
return 0;

try
{
_process.ReadMemory(_address + _position, buffer[..len]);
fixed (byte* p = buffer)
_process.ReadMemory((byte*)_address + (nuint)_position, p, len);
}
catch (Win32Exception ex)
{
throw new IOException(null, ex);
}

_position += (uint)len;
_position += len;

return len;
}
Expand All @@ -126,7 +127,7 @@ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken
return ValueTask.FromResult(Read(buffer.Span));
}

public override unsafe int ReadByte()
public override int ReadByte()
{
byte value;

Expand All @@ -151,20 +152,21 @@ public override Task WriteAsync(

public override void Write(ReadOnlySpan<byte> buffer)
{
_ = _position + (uint)buffer.Length <= _length ? true : throw new NotSupportedException();
_ = _position + buffer.Length <= _length ? true : throw new NotSupportedException();

_wrote = true;

try
{
_process.WriteMemory(_address + _position, buffer);
fixed (byte* p = buffer)
_process.WriteMemory((byte*)_address + (nuint)_position, p, buffer.Length);
}
catch (Win32Exception ex)
{
throw new IOException(null, ex);
}

_position += (uint)buffer.Length;
_position += buffer.Length;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
Expand Down
11 changes: 0 additions & 11 deletions src/injection/NativeMethods.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,8 @@
CreateProcessW
CreateRemoteThreadEx
CreateToolhelp32Snapshot
FlushInstructionCache
GetExitCodeThread
IsWow64Process2
K32GetModuleBaseNameW
Module32FirstW
Module32NextW
OpenProcess
ReadProcessMemory
VirtualAlloc2
VirtualFreeEx
VirtualProtectEx
WaitForSingleObjectEx
WriteProcessMemory

WIN32_ERROR

Expand Down
Loading

0 comments on commit 180dfb8

Please sign in to comment.