//////////////////////////////////////////////////////////////
//                                                          //
//  Direct3D 9 Hook.cpp                                     //
//  Version 1.0                                             //
//  Greg Jenkins, November 2007                             //
//  Ring3 Circus (www.ring3circus.com)                      //
//  Creative Commons Attribution 3.0 Unported License       //
//                                                          //
//  Sample DLL to demonstrate hooking of the Present        //
//  methods of IDirect3DDevice9 and IDirect3DSwapChain9     //
//  (plus IDirect3DDevice9::Clear) to draw a user           //
//  overlay onto arbitrary DirectX 9 programs.              //
//                                                          //
//////////////////////////////////////////////////////////////

#include <string>
#include <Windows.h>
#include <d3d9.h>
#include <d3dx9.h>

template<class COMObject>
void SafeRelease(COMObject*& pRes)
{
    IUnknown *unknown = pRes;
    if (unknown)
    {
        unknown->Release();
    }
    pRes = NULL;
}

#pragma comment(lib, "d3dx9.lib")

// Byte offsets of the hooked functions from the base address of d3d9.dll.
// NOTE(review): these are tied to one particular build of d3d9.dll; VerifyAddresses
// only checks that each address begins with the expected prologue bytes below.
const long OFFSET_DEVICE_PRESENT = 0x00040EA0;
const long OFFSET_SWAP_CHAIN_PRESENT = 0x00039230;
const long OFFSET_CLEAR = 0x00085720;

// Five bytes each hooked function is expected to start with:
// 8B FF / 55 / 8B EC  =  mov edi, edi / push ebp / mov ebp, esp.
// Read-only reference data, hence const.
const unsigned char EXPECTED_OPCODES_DEVICE_PRESENT[5] = {0x8B, 0xFF, 0x55, 0x8B, 0xEC};
const unsigned char EXPECTED_OPCODES_SWAP_CHAIN_PRESENT[5] = {0x8B, 0xFF, 0x55, 0x8B, 0xEC};
const unsigned char EXPECTED_OPCODES_CLEAR[5] = {0x8B, 0xFF, 0x55, 0x8B, 0xEC};

// Absolute addresses of the hooked functions (d3d9.dll base + offset), set by Initialise.
void* address_DevicePresent;
void* address_SwapChainPresent;
void* address_Clear;

// backup_*: the original first five bytes at each address.
// patch_*:  the five-byte JMP written over them to divert calls to the hooks.
unsigned char backup_DevicePresent[5];
unsigned char patch_DevicePresent[5];
unsigned char backup_SwapChainPresent[5];
unsigned char patch_SwapChainPresent[5];
unsigned char backup_Clear[5];
unsigned char patch_Clear[5];

// Full path of the system d3d9.dll, built by CreatePath.
char path_d3d9_dll[MAX_PATH];
bool patches_created = false;   // set by Initialise once its one-time setup has succeeded
bool resources_created = false; // set by CreateResourcesD3D9 once the overlay font exists

void WriteHook(void* address, unsigned char* patch);

ID3DXFont* font_object = NULL;  // overlay font; NULL until created
HMODULE module_self = NULL;     // this DLL's own module handle, recorded in DllMain

void CreateResourcesD3D9(IDirect3DDevice9* device) {
	// Creates the overlay font (stored in font_object) on 'device'. On success it sets
	// resources_created; on failure the flag stays false, so PreClear will try again
	// on its next call.
	//
	// D3DXCreateFontA is named explicitly because the face name is a narrow string
	// literal. The generic D3DXCreateFont macro resolves to the wide-character
	// version in UNICODE builds and would not compile there.
	//
	// NOTE(review): the font belongs to whichever device reached this first; no
	// ID3DXFont::OnLostDevice/OnResetDevice handling is visible in this file, which
	// may matter if the application calls IDirect3DDevice9::Reset.
	HRESULT hr = D3DXCreateFontA(device,
							 40,			// Height
							 0,				// Width (0 default)
							 FW_NORMAL,		// Weight
							 1,				// MipLevels
							 FALSE,			// Italic
							 DEFAULT_CHARSET,
							 OUT_DEFAULT_PRECIS,
							 DEFAULT_QUALITY,
							 DEFAULT_PITCH,
							 "Franklin Gothic Medium",
							 &font_object);
	if (SUCCEEDED(hr)) {
		resources_created = true;
	}
}

void PrePresent(IDirect3DDevice9* device) {
	// Runs just before the real Present. If the overlay font exists, draws the
	// status message twice: once in dark grey (0xFF404040), then in a light
	// colour (0xFFE0FFA0) shifted two pixels right and up, giving a drop shadow.

	if (font_object == NULL) return;

	const std::string message("Direct3D 9 hook up and running.");
	const int length = static_cast<int> (message.size());

	// RECT is { left, top, right, bottom }.
	RECT target_rect = { 50, 50, 800, 100 };
	font_object->DrawTextA(NULL, message.c_str(), length, &target_rect, 0, 0xFF404040);

	target_rect.left += 2;
	target_rect.top -= 2;
	font_object->DrawTextA(NULL, message.c_str(), length, &target_rect, 0, 0xFFE0FFA0);
}

void PostPresent(IDirect3DDevice9* device) {
	// Hook point invoked straight after the real Present returns.
	// Intentionally empty; put any post-frame work here.
	(void)device;
}

void PreClear(IDirect3DDevice9* device) {
	// Hook point invoked just before the real Clear. Used to create the overlay
	// font lazily on the device being cleared, retrying on each call until
	// CreateResourcesD3D9 reports success through resources_created.
	if (resources_created) return;
	CreateResourcesD3D9(device);
}

void PostClear(IDirect3DDevice9* device) {
	// Hook point invoked straight after the real Clear returns.
	// Intentionally empty; put any post-clear work here.
	(void)device;
}

HRESULT __stdcall SwapChainPresentHook(IDirect3DSwapChain9* swap_chain, const RECT* pSourceRect, const RECT* pDestRect, HWND hDestWindowOverride, const RGNDATA* pDirtyRegion, DWORD dwFlags) {
	// The hook function for IDirect3DSwapChain9::Present, reached through the JMP
	// written at address_SwapChainPresent. Runs PrePresent/PostPresent around the
	// genuine Present (when the owning device can be obtained) and returns
	// Present's result unchanged.

	IDirect3DDevice9* device = NULL;
	// Failure is unlikely, but if it happens we must still forward the call.
	const bool have_device = SUCCEEDED(swap_chain->GetDevice(&device));

	if (have_device) PrePresent(device);

	// Put the original bytes back for the duration of the real call. This is done on
	// every path: calling Present with the JMP still in place would re-enter this
	// hook and recurse without end.
	WriteHook(address_SwapChainPresent, backup_SwapChainPresent);
	HRESULT return_value = swap_chain->Present(pSourceRect, pDestRect, hDestWindowOverride, pDirtyRegion, dwFlags);
	WriteHook(address_SwapChainPresent, patch_SwapChainPresent);

	if (have_device) {
		PostPresent(device);
		// GetDevice added a reference to the device; drop it, otherwise one
		// reference leaks per frame and the device can never be destroyed.
		device->Release();
	}

	return return_value;
}

HRESULT __stdcall DevicePresentHook(IDirect3DDevice9* device, const RECT* pSourceRect, const RECT* pDestRect, HWND hDestWindowOverride, const RGNDATA* pDirtyRegion) {
	// Detour target for IDirect3DDevice9::Present. Calls PrePresent, briefly
	// restores the original bytes so the genuine Present can run, re-arms the
	// detour, calls PostPresent and hands back Present's result untouched.

	PrePresent(device);

	WriteHook(address_DevicePresent, backup_DevicePresent);    // unhook
	const HRESULT result = device->Present(pSourceRect, pDestRect, hDestWindowOverride, pDirtyRegion);
	WriteHook(address_DevicePresent, patch_DevicePresent);     // rehook

	PostPresent(device);
	return result;
}

HRESULT __stdcall ClearHook(IDirect3DDevice9* device, DWORD Count, CONST D3DRECT * pRects, DWORD Flags, D3DCOLOR Color, float Z, DWORD Stencil) {
	// Detour target for IDirect3DDevice9::Clear. Calls PreClear, briefly restores
	// the original bytes so the genuine Clear can run, re-arms the detour, calls
	// PostClear and hands back Clear's result untouched.

	PreClear(device);

	WriteHook(address_Clear, backup_Clear);    // unhook
	const HRESULT result = device->Clear(Count, pRects, Flags, Color, Z, Stencil);
	WriteHook(address_Clear, patch_Clear);     // rehook

	PostClear(device);
	return result;
}

void CreatePath() {
	// We only need to do this once, so let's separate it from Initialise

	GetSystemDirectory(path_d3d9_dll, MAX_PATH);	
	strncat_s(path_d3d9_dll, MAX_PATH, "\\d3d9.dll", 10);
}

void WriteHook(void* address, unsigned char* patch) {
	// This will write the five-byte buffer at 'patch' to 'address',
	// after making sure that it's safe to write there

	// Set access
	DWORD old_protect = 0;
	if (VirtualProtect(address, 5, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return;
	if (IsBadWritePtr(address, 5) != FALSE) return;

	memcpy(address, reinterpret_cast<void*> (patch), 5);
}

bool VerifyAddresses() {
	// Checks that each saved prologue matches its expected byte signature
	// (8B FF 55 8B EC: mov edi,edi / push ebp / mov ebp,esp). Initialise refuses to
	// patch anything unless all three match, because overwriting code at an
	// unverified address could corrupt arbitrary functions.
	return memcmp(backup_DevicePresent, EXPECTED_OPCODES_DEVICE_PRESENT, 5) == 0
		&& memcmp(backup_SwapChainPresent, EXPECTED_OPCODES_SWAP_CHAIN_PRESENT, 5) == 0
		&& memcmp(backup_Clear, EXPECTED_OPCODES_CLEAR, 5) == 0;
}

DWORD WINAPI Initialise(__in  LPVOID lpParameter) {
	// We define this as a ThreadProc so that it may be used with CreateRemoteThread
	// This function may be called many times, but must succeed at least once for the hooks to work
	
	// Look for a suitable DLL
	char* address_d3d9 = reinterpret_cast<char*> (GetModuleHandleA(path_d3d9_dll));
	if (address_d3d9 == NULL) return 0;

	// Calculate addresses
	address_DevicePresent = reinterpret_cast<void*> (address_d3d9 + OFFSET_DEVICE_PRESENT);
	address_SwapChainPresent = reinterpret_cast<void*> (address_d3d9 + OFFSET_SWAP_CHAIN_PRESENT);
	address_Clear = reinterpret_cast<void*> (address_d3d9 + OFFSET_CLEAR);

	// Create backups
	if (!patches_created) {
		DWORD old_protect = 0;
		
		if (VirtualProtect(address_DevicePresent, 5, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return 0;
		memcpy(backup_DevicePresent, address_DevicePresent, 5);

		if (VirtualProtect(address_SwapChainPresent, 5, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return 0;
		memcpy(backup_SwapChainPresent, address_SwapChainPresent, 5);

		if (VirtualProtect(address_Clear, 5, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return 0;
		memcpy(backup_Clear, address_Clear, 5);

		if (!VerifyAddresses()) return 0;

		// We need the DLL to add a reference to itself here to prevent it being unloaded
		// before the process terminates, as there is no way to guarantee that unloading won't
		// occur at an inopportune moment and send everything down the pan.
		// This means that the DLL can't ever be unloaded, but that's the price we pay.
		char file_name_self[MAX_PATH];
		GetModuleFileName(module_self, file_name_self, MAX_PATH);
		LoadLibrary(file_name_self);

		// Create patches
		{
			DWORD from_int, to_int, offset;

			// Device::Present
			from_int = reinterpret_cast<DWORD> (address_DevicePresent);
			to_int = reinterpret_cast<DWORD> (&DevicePresentHook);
			offset = to_int - from_int - 5;

			patch_DevicePresent[0] = 0xE9; // Rel32 JMP
			*(reinterpret_cast<DWORD*> (patch_DevicePresent + 1)) = offset;

			// SwapChain::Present
			from_int = reinterpret_cast<DWORD> (address_SwapChainPresent);
			to_int = reinterpret_cast<DWORD> (&SwapChainPresentHook);
			offset = to_int - from_int - 5;

			patch_SwapChainPresent[0] = 0xE9; // Rel32 JMP
			*(reinterpret_cast<DWORD*> (patch_SwapChainPresent + 1)) = offset;

			// Clear
			from_int = reinterpret_cast<DWORD> (address_Clear);
			to_int = reinterpret_cast<DWORD> (&ClearHook);
			offset = to_int - from_int - 5;

			patch_Clear[0] = 0xE9; // Rel32 JMP
			*(reinterpret_cast<DWORD*> (patch_Clear + 1)) = offset;
		}
		
		patches_created = true;
	}

	// Install hooks
	WriteHook(address_DevicePresent, patch_DevicePresent);
	WriteHook(address_SwapChainPresent, patch_SwapChainPresent);
	WriteHook(address_Clear, patch_Clear);

	return -1;
}

DWORD WINAPI Release(void* parameter) {
	// Undoes Initialise: restores the original function prologues, sets the patched
	// pages to PAGE_EXECUTE_READ and frees the overlay font. Declared as a ThreadProc
	// so it can be started with CreateRemoteThread; 'parameter' is unused. Always
	// returns 0.
	// NOTE(review): not synchronised with the render thread. A hook that is already
	// executing will re-apply its patch when its forwarded call returns.

	// Remove all patches first, so that later calls go straight to the original
	// functions instead of into hooks that use the font released below. Only done
	// once Initialise has completed its setup; before that the backup buffers
	// may never have been filled.
	if (patches_created) {
		WriteHook(address_DevicePresent, backup_DevicePresent);
		WriteHook(address_SwapChainPresent, backup_SwapChainPresent);
		WriteHook(address_Clear, backup_Clear);
	}

	// Reset access (fails harmlessly for addresses that were never resolved)
	DWORD old_protect = 0;
	DWORD new_protect = PAGE_EXECUTE_READ;
	VirtualProtect(address_DevicePresent, 5, new_protect, &old_protect);
	VirtualProtect(address_SwapChainPresent, 5, new_protect, &old_protect);
	VirtualProtect(address_Clear, 5, new_protect, &old_protect);

	// Release resources. SafeRelease handles a NULL pointer itself. Clearing the flag
	// lets PreClear create the font again if the hooks are re-installed later;
	// otherwise the overlay would silently stop drawing.
	SafeRelease(font_object);
	resources_created = false;

	return 0;
}

BOOL APIENTRY DllMain(HINSTANCE hModule, DWORD dwReason, PVOID lpReserved)
{
	// DLL entry point. On attach, records this module's handle (used later to pin
	// the DLL) and builds the d3d9.dll path. On detach, removes the hooks and frees
	// resources, but only when the DLL is being unloaded dynamically.
	switch (dwReason)
	{
	case DLL_PROCESS_ATTACH:
		module_self = hModule;
		CreatePath();
		break;
	case DLL_PROCESS_DETACH:
		// A non-NULL lpReserved means the process is terminating. Other threads have
		// already been stopped, possibly mid-call, and other DLLs may already have
		// been detached. Calling into d3d9/d3dx9 (COM Release, code patching) is
		// unsafe then, so let the OS reclaim everything instead.
		if (lpReserved == NULL) Release(NULL);
		break;
	}
	return TRUE;
}