-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalgorithm_firescarmapper.py
More file actions
436 lines (352 loc) · 19.6 KB
/
algorithm_firescarmapper.py
File metadata and controls
436 lines (352 loc) · 19.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
# -*- coding: utf-8 -*-
"""
/***************************************************************************
FireScarMapper
A QGIS plugin
Generate georeferenced fire scar rasters using a pre-trained U-Net model and analyze the impact of fire events by comparing pre- and post-fire satellite images.
Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
-------------------
begin : 2024-11-25
git sha : $Format:%H$
copyright : (C) 2024 by Fire 2A
email : N/A
***************************************************************************/
/***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
***************************************************************************/
"""
# Initialize Qt resources from file resources.py
# Import the code for the dialog
from .resources import *
import os.path
from qgis.core import (Qgis, QgsProcessingAlgorithm,QgsProject, QgsRasterLayer, QgsProcessingException,
QgsSingleBandPseudoColorRenderer, QgsRasterMinMaxOrigin, QgsColorRampShader, QgsRasterShader, QgsStyle, QgsContrastEnhancement, QgsMultiBandColorRenderer)
from qgis.PyQt.QtCore import QCoreApplication
import torch
from .firescarmapping.model_u_net import model, device
from .firescarmapping.as_dataset import create_datasetAS
from .firescarmapping.dataset_128 import create_dataset128
import numpy as np
from torch.utils.data import DataLoader
import os
from osgeo import gdal, gdal_array
import requests
class ProcessingAlgorithm(QgsProcessingAlgorithm):
    """Fire-scar mapping algorithm backed by a pre-trained U-Net.

    ``main`` pairs pre-fire ("before") and post-fire ("after") rasters by
    index and runs them through the U-Net imported from
    ``firescarmapping.model_u_net``. It writes each binary prediction as a
    GeoTIFF under ``results/<ModelScale>/`` (next to this file) and adds it,
    together with the input pair, to a collapsed layer-tree group.
    """

    # Weight file per supported model scale. Files are fetched from
    # MODEL_BASE_URL on first use and cached in the ``firescarmapping``
    # directory next to this file.
    MODEL_FILES = {
        "AS": "ep25_lr1e-04_bs16_021__as_std_adam_f01_13_07_x3.model",
        "128": "ep25_lr1e-04_bs16_014_128_std_25_08_mult3_adam01.model",
    }
    MODEL_BASE_URL = "https://fire2a-firescar-as-model.s3.amazonaws.com/"

    def main(self, parameters, context, feedback):
        """Run fire-scar inference for every before/after raster pair.

        Args:
            parameters (dict): Must contain ``'BeforeRasters'`` and
                ``'AfterRasters'`` (raster file paths, paired by index) and
                ``'ModelScale'`` (a key of ``MODEL_FILES``: ``"AS"`` or ``"128"``).
            context: Processing context, forwarded to ``addRasterLayer``.
            feedback: Processing feedback used for progress and log messages.

        Returns:
            dict: Always empty; the outputs are added to the current project.

        Raises:
            QgsProcessingException: If the model scale is unknown, the
                before/after lists differ in length, an input cannot be
                loaded, the model cannot be downloaded, or an output raster
                cannot be written or loaded.
        """
        before_paths = parameters['BeforeRasters']
        burnt_paths = parameters['AfterRasters']
        datatype = parameters['ModelScale']
        # Model weights and dataset builder must agree on the scale, so reject
        # anything that has no weights file.
        if datatype not in self.MODEL_FILES:
            raise QgsProcessingException(
                f"Unsupported model scale {datatype!r}; expected one of: {', '.join(self.MODEL_FILES)}"
            )
        # Validate the pairing before both lists are walked in lockstep, so
        # surplus 'AfterRasters' entries are not silently ignored.
        if len(before_paths) != len(burnt_paths):
            raise QgsProcessingException("The number of before and burnt rasters must be the same")
        not_cropped_paths = list(before_paths) + list(burnt_paths)

        results_dir = os.path.join(os.path.dirname(__file__), 'results')
        model_scale_dir = os.path.join(results_dir, datatype)
        # Create the 'results' directory and its per-scale subdirectory if missing.
        if not os.path.exists(results_dir):
            os.makedirs(results_dir, exist_ok=True)
            feedback.pushInfo(f"Created main results directory at: {results_dir}")
        if not os.path.exists(model_scale_dir):
            os.makedirs(model_scale_dir, exist_ok=True)
            feedback.pushInfo(f"Created directory for specified model at: {model_scale_dir}")

        # Example input file name: ImgPosF_CL-BI_ID74101_u350_19980330_clip
        before = [self._load_input_layer(path) for path in before_paths]
        burnt = [self._load_input_layer(path) for path in burnt_paths]
        n_pairs = len(before)

        # One info dict per layer: all "before" layers first, then all "burnt" ones.
        rasters = []
        for i, (layer, path) in enumerate(zip(before + burnt, not_cropped_paths)):
            basename = os.path.splitext(os.path.basename(layer.source()))[0]
            stem = os.path.splitext(layer.name())[0]
            feedback.pushInfo(f"layer.id(): {layer.id()}")
            feedback.pushInfo(f"layer.name(): {layer.name()}")
            feedback.pushInfo(f"layer.name() 2: {stem}")
            info = {
                "type": "before" if i < n_pairs else "burnt",
                "id": i,
                "qid": layer.id(),
                "name": stem,
                "data": self.get_rlayer_data(layer),
                "layer": layer,
                "path": path,
                "not_cropped_path": path,
                "output_path": os.path.join(model_scale_dir, f"FireScar_{basename}.tif"),
            }
            info.update(self.get_rlayer_info(layer))
            rasters.append(info)
        before_files, after_files = rasters[:n_pairs], rasters[n_pairs:]
        before_files_data = [r['data'] for r in before_files]
        after_files_data = [r['data'] for r in after_files]

        model_name = self.MODEL_FILES[datatype]
        model_path = os.path.join(os.path.dirname(__file__), 'firescarmapping', model_name)
        if not os.path.exists(model_path):
            feedback.pushInfo("Model not found. Initializing download...")
            self.download_model(model_path, self.MODEL_BASE_URL + model_name, feedback)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): torch.load unpickles the file. It comes from the
        # project's own bucket, but consider weights_only=True (torch >= 1.13)
        # so a tampered file cannot execute arbitrary code.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        # Put the network on the same device the input batches are moved to below.
        model.to(device)
        model.eval()

        # Fixed RNG seeds (presumably for reproducible dataset construction).
        np.random.seed(3)
        torch.manual_seed(3)
        if datatype == "128":
            data_eval = create_dataset128(before_files_data, after_files_data, mult=1)
        else:
            data_eval = create_datasetAS(before_files_data, after_files_data, mult=1)
        # batch_size must stay 1: batch index i is matched to raster pair i below.
        all_dl = DataLoader(data_eval, batch_size=1)

        root = QgsProject.instance().layerTreeRoot()
        for i, batch in enumerate(all_dl):
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                output = model(batch['img'].float().to(device)).cpu()
            # Binary prediction map: 1 where the raw network output is >= 0, else 0.
            generated_matrix = (output[0][0] >= 0).numpy().astype(np.uint8)

            group_name = f"FireScarGroup_{i+1} ({datatype})"
            group = root.findGroup(group_name)
            if not group:
                group = root.addGroup(group_name)
            # Collapse the group so it is shown minimized in the Layers panel.
            group.setExpanded(False)

            before_info, after_info = before_files[i], after_files[i]
            # Layer labels use the last '_'-separated token of each input name.
            pre_file_name = before_info['name'].split("\\")[-1].split("_")[-1]
            post_file_name = after_info['name'].split("\\")[-1].split("_")[-1]
            self.writeRaster(generated_matrix, before_info['output_path'], before_info, feedback)
            self.addRasterLayer(before_info['output_path'], f"FireScar_{post_file_name}", group, "FireScar", context)
            self.addRasterLayer(after_info['not_cropped_path'], f"ImgPosF_{post_file_name}", group, "ImgPosF", context)
            self.addRasterLayer(before_info['not_cropped_path'], f"ImgPreF_{pre_file_name}", group, "ImgPreF", context)
        return {}

    def _load_input_layer(self, path):
        """Open ``path`` as a GDAL raster layer named after its file name; raise if invalid."""
        layer = QgsRasterLayer(path, os.path.basename(path), "gdal")
        if not layer.isValid():
            raise QgsProcessingException(f"Failed to load raster layer from {path}")
        return layer

    def download_model(self, model_path, download_url, feedback):
        """Download model weights from Amazon S3, reporting progress.

        The payload is streamed to ``<model_path>.part`` and renamed to
        ``model_path`` only once complete. An interrupted or failed download
        therefore never leaves a truncated file that a later run would treat
        as a valid model.

        Args:
            model_path (str): Final destination of the weights file.
            download_url (str): URL to fetch.
            feedback: Processing feedback. ``setProgress``/``isCanceled`` are
                used when the object provides them.

        Raises:
            QgsProcessingException: If the request fails, the size is unknown,
                the transfer is truncated or cancelled, or the file cannot be
                written.
        """
        chunk_size = 1048576  # 1 MB
        megabyte = 1024 * 1024
        partial_path = model_path + ".part"
        try:
            # The with-statement closes both the pooled connection and the streamed response.
            with requests.Session() as session, session.get(download_url, stream=True, timeout=30) as response:
                response.raise_for_status()  # raise on HTTP error status
                total_size = int(response.headers.get('Content-Length', 0))
                if total_size == 0:
                    raise requests.exceptions.RequestException("Unable to determine the file size.")
                feedback.pushInfo(f"Downloading model to {model_path} ({total_size // megabyte} MB)")
                bytes_downloaded = 0
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size):
                        if not chunk:  # skip empty keep-alive chunks
                            continue
                        if hasattr(feedback, 'isCanceled') and feedback.isCanceled():
                            raise QgsProcessingException("Model download cancelled.")
                        f.write(chunk)
                        bytes_downloaded += len(chunk)
                        if hasattr(feedback, 'setProgress'):
                            feedback.setProgress(int(bytes_downloaded / total_size * 100))
                        feedback.pushInfo(
                            f"Downloaded {bytes_downloaded // megabyte} MB of {total_size // megabyte} MB"
                        )
                if bytes_downloaded < total_size:
                    raise requests.exceptions.RequestException(
                        f"Download incomplete: received {bytes_downloaded} of {total_size} bytes."
                    )
            os.replace(partial_path, model_path)
            feedback.pushInfo(f"Model successfully downloaded and saved at {model_path}")
        except (requests.exceptions.RequestException, OSError) as e:
            feedback.pushInfo(f"Failed to download model: {str(e)}")
            raise QgsProcessingException(f"Failed to download model: {e}") from e
        finally:
            # Best-effort cleanup. The partial file no longer exists after a
            # successful os.replace, so this only triggers on failure or cancellation.
            if os.path.exists(partial_path):
                try:
                    os.remove(partial_path)
                except OSError:
                    pass

    def qgis2numpy_dtype(self, qgis_dtype: Qgis.DataType) -> np.dtype:
        """Convert a QGIS raster data type to the corresponding numpy dtype.

        Adapted from
        https://raw.githubusercontent.com/PUTvision/qgis-plugin-deepness/fbc99f02f7f065b2f6157da485bef589f611ea60/src/deepness/processing/processing_utils.py

        Accepts either a ``Qgis.DataType`` member or its name as a string
        (e.g. ``"Float32"``). Supported types: Byte, Int8 (QGIS >= 3.30),
        UInt16, Int16, UInt32, Int32, Float32 and Float64. Complex and ARGB
        types are not supported.

        Raises:
            QgsProcessingException: For an unsupported or unknown type. Returning
                a placeholder would make ``np.frombuffer`` silently misread the
                raster bytes.
        """
        supported = (
            ("Byte", np.uint8),
            ("Int8", np.int8),
            ("UInt16", np.uint16),
            ("Int16", np.int16),
            ("UInt32", np.uint32),
            ("Int32", np.int32),
            ("Float32", np.float32),
            ("Float64", np.float64),
        )
        for type_name, np_type in supported:
            # getattr with a default: Int8 only exists in newer QGIS releases.
            qgis_member = getattr(Qgis.DataType, type_name, None)
            if (qgis_member is not None and qgis_dtype == qgis_member) or qgis_dtype == type_name:
                return np_type
        raise QgsProcessingException(f"Unsupported raster data type: {qgis_dtype}")

    def get_rlayer_info(self, layer: QgsRasterLayer):
        """Collect grid, CRS and per-band nodata metadata for a raster layer.

        Args:
            layer (QgsRasterLayer): A raster layer.

        Returns:
            dict: ``width``, ``height``, ``extent``, ``crs``, ``cellsize_x``,
            ``cellsize_y``, ``nodata`` (one entry per band, ``None`` when the
            source has no nodata value) and ``bands``.
        """
        provider = layer.dataProvider()
        band_count = layer.bandCount()
        nodata = [
            provider.sourceNoDataValue(band) if provider.sourceHasNoDataValue(band) else None
            for band in range(1, band_count + 1)
        ]
        return {
            "width": layer.width(),
            "height": layer.height(),
            "extent": layer.extent(),
            "crs": layer.crs(),
            "cellsize_x": layer.rasterUnitsPerPixelX(),
            "cellsize_y": layer.rasterUnitsPerPixelY(),
            "nodata": nodata,
            "bands": band_count,
        }

    def get_rlayer_data(self, layer: QgsRasterLayer):
        """Read every band of a raster layer into a numpy array.

        A single-band raster yields a 2-D ``(height, width)`` array. A
        multiband raster yields a 3-D ``(bands, height, width)`` array, so
        callers should inspect ``data.shape``.

        Args:
            layer (QgsRasterLayer): A raster layer.

        Returns:
            np.ndarray: The raster values, read over the full layer extent.

        Raises:
            QgsProcessingException: If a band's data type is unsupported.
        """
        provider = layer.dataProvider()
        extent, width, height = layer.extent(), layer.width(), layer.height()

        def read_band(band_no):
            # One provider block per band, covering the whole layer extent.
            block = provider.block(band_no, extent, width, height)
            np_dtype = self.qgis2numpy_dtype(provider.dataType(band_no))
            return np.frombuffer(block.data(), dtype=np_dtype).reshape(height, width)

        if layer.bandCount() == 1:
            return read_band(1)
        # NOTE(review): bands with different dtypes are upcast to a common dtype by np.array.
        return np.array([read_band(band_no) for band_no in range(1, layer.bandCount() + 1)])

    def writeRaster(self, matrix, file_path, before_layer, feedback):
        """Write a 2-D prediction ``matrix`` as a single-band Byte GeoTIFF.

        The output grid (size, geotransform, CRS) is taken from the pre-fire
        raster's info dict. Pixel value 0 is declared as nodata.

        Args:
            matrix: 2-D array of 0/1 predictions.
            file_path (str): Output GeoTIFF path.
            before_layer (dict): Info dict with ``width``, ``height``, ``extent``
                and ``crs`` (as built by ``get_rlayer_info``).
            feedback: Processing feedback for log messages.

        Raises:
            QgsProcessingException: If ``matrix`` has no non-zero pixel, the file
                cannot be created, or the array cannot be written.
        """
        if np.count_nonzero(matrix) == 0:
            raise QgsProcessingException("The generated fire scar matrix is empty. No valid pixels were found.")
        width = before_layer["width"]
        height = before_layer["height"]
        extent = before_layer["extent"]

        # Centre-crop the prediction when it is larger than the target grid. A
        # smaller prediction is written from the top-left corner (offset 0, 0).
        matrix_height, matrix_width = matrix.shape
        start_row = (matrix_height - height) // 2 if matrix_height > height else 0
        start_col = (matrix_width - width) // 2 if matrix_width > width else 0
        resized_matrix = matrix[
            start_row:start_row + min(matrix_height, height),
            start_col:start_col + min(matrix_width, width),
        ]

        driver = gdal.GetDriverByName('GTiff')
        raster = driver.Create(file_path, width, height, 1, gdal.GDT_Byte)
        if raster is None:
            raise QgsProcessingException("Failed to create raster file.")
        band = None
        try:
            pixel_width = extent.width() / width
            pixel_height = extent.height() / height
            # North-up geotransform anchored at the top-left corner of the extent.
            raster.SetGeoTransform((extent.xMinimum(), pixel_width, 0, extent.yMaximum(), 0, -pixel_height))
            raster.SetProjection(before_layer["crs"].toWkt())
            band = raster.GetRasterBand(1)
            try:
                gdal_array.BandWriteArray(band, resized_matrix, 0, 0)
            except ValueError as e:
                raise QgsProcessingException(f"Failed to write array to raster: {str(e)}") from e
            band.SetNoDataValue(0)
            # Compute, then pin the statistics to the fixed 0..1 range of the mask.
            band.ComputeStatistics(False)
            band.SetStatistics(0, 1, 0.5, 0.5)
            band.FlushCache()
            raster.FlushCache()
        finally:
            # Drop the GDAL handles even on failure so the dataset is closed.
            band = None
            raster = None
        feedback.pushInfo(f"Raster written to {file_path}")

    def addRasterLayer(self, file_path, layer_name, group, tif_type, context):
        """Load a raster file and add it to ``group`` in the current project.

        When ``tif_type`` is ``"FireScar"`` the layer gets a single-band
        pseudocolor renderer using the default style's "Reds" ramp (if present).

        Raises:
            QgsProcessingException: If the file cannot be loaded as a valid layer.
        """
        layer = QgsRasterLayer(file_path, layer_name, "gdal")
        if not layer.isValid():
            raise QgsProcessingException(f"Failed to load raster layer from {file_path}")
        # addToLegend=False: the layer is inserted into the given group instead of the root.
        QgsProject.instance().addMapLayer(layer, False)
        group.addLayer(layer)
        if tif_type == "FireScar":
            # Band statistics give the min/max used to configure the color ramp.
            provider = layer.dataProvider()
            # NOTE(review): the second argument of bandStatistics is the set of
            # statistics to compute (QgsRasterBandStats flags), not a
            # QgsRasterMinMaxOrigin accuracy value. Confirm this returns both min and max.
            stats = provider.bandStatistics(1, QgsRasterMinMaxOrigin.Estimated)
            min_value = stats.minimumValue
            max_value = stats.maximumValue
            # Interpolated color ramp between min and max.
            shader = QgsRasterShader()
            color_ramp_shader = QgsColorRampShader(minimumValue=min_value, maximumValue=max_value)
            color_ramp_shader.setColorRampType(QgsColorRampShader.Interpolated)
            # Use the "Reds" ramp from the default QGIS style when available.
            ramp = QgsStyle.defaultStyle().colorRamp('Reds')
            if ramp:
                color_ramp_shader.setSourceColorRamp(ramp)
            shader.setRasterShaderFunction(color_ramp_shader)
            renderer = QgsSingleBandPseudoColorRenderer(layer.dataProvider(), 1, shader)
            layer.setRenderer(renderer)
            # Stretch the renderer to the data's min/max so the mask is displayed correctly.
            layer.setContrastEnhancement(QgsContrastEnhancement.StretchToMinimumMaximum)
            # Force a redraw with the new renderer.
            layer.triggerRepaint()
            layer.reload()

    def name(self):
        """Return the unique algorithm identifier."""
        return "firescarmapper"

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr("Fire Scar Mapper")

    def tr(self, string):
        """Translate ``string`` in the "Processing" context."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a fresh instance of this algorithm."""
        return ProcessingAlgorithm()