Commit 9851d5ea by Gencer

ResNet pre-trained models for PyTorch are added

parent 9dd9a4bc
We provide code and model weights for the following deep learning models that have been pre-trained on BigEarthNet with the original Level-3 class nomenclature of CLC 2018 for scene classification:
| Model Names | Pre-Trained TensorFlow Models | Pre-Trained PyTorch Models |
| ------------- | ------------- | ------------- |
| K-Branch CNN | [K-BranchCNN.zip](http://bigearth.net/static/pretrained-models/original_labels/K-BranchCNN.zip) | Coming soon |
| VGG16 | [VGG16.zip](http://bigearth.net/static/pretrained-models/original_labels/VGG16.zip) | Coming soon |
| VGG19 | [VGG19.zip](http://bigearth.net/static/pretrained-models/original_labels/VGG19.zip) | Coming soon |
| ResNet50 | [ResNet50.zip](http://bigearth.net/static/pretrained-models/original_labels/ResNet50.zip) | [ResNet50.pth.tar](http://bigearth.net/static/pretrained-models-pytorch/original_labels/ResNet50.pth.tar) |
| ResNet101 | [ResNet101.zip](http://bigearth.net/static/pretrained-models/original_labels/ResNet101.zip) | [ResNet101.pth.tar](http://bigearth.net/static/pretrained-models-pytorch/original_labels/ResNet101.pth.tar) |
| ResNet152 | [ResNet152.zip](http://bigearth.net/static/pretrained-models/original_labels/ResNet152.zip) | [ResNet152.pth.tar](http://bigearth.net/static/pretrained-models-pytorch/original_labels/ResNet152.pth.tar) |
The results provided in the [BigEarthNet paper](http://bigearth.net/static/documents/BigEarthNet_IGARSS_2019.pdf) are different from those obtained by the models given above due to the selection of different train, validation and test sets.
The TensorFlow code for these models can be found [here](https://gitlab.tu-berlin.de/rsim/bigearthnet-models-tf).
The PyTorch code for these models can be found [here](https://gitlab.tu-berlin.de/rsim/bigearthnet-models-pytorch).
# Generation of Training/Test/Validation Splits
After downloading the raw images from https://www.bigearth.net, they need to be prepared for your ML application. We provide the script `prep_splits.py` for this purpose. It generates consumable data files (i.e., TFRecord for TensorFlow or LMDB for PyTorch) for the training, validation and test splits. Suggested splits are provided as CSV files under the `splits` folder. The following command line arguments for `prep_splits.py` can be specified:
* `-r` or `--root_folder`: The root folder containing the raw images you have previously downloaded.
* `-o` or `--out_folder`: The output folder where the resulting files will be created.
* `-n` or `--splits`: A list of CSV files, each of which contains the patch names of the corresponding split.
* `-l` or `--library`: A flag indicating for which ML library the data files will be prepared: TensorFlow or PyTorch.
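A typical invocation might look like the following; the folder paths and CSV file names are hypothetical examples, so substitute the actual locations of your downloaded archive and split files:

```shell
# Prepare LMDB files for PyTorch from the downloaded BigEarthNet patches.
# /data/BigEarthNet-v1.0 and the CSV names below are placeholder paths.
python prep_splits.py \
    -r /data/BigEarthNet-v1.0 \
    -o /data/bigearthnet_out \
    -n splits/train.csv splits/val.csv splits/test.csv \
    -l pytorch
```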
To run the script, either the GDAL or the rasterio package must be installed, together with TensorFlow and/or PyTorch depending on the selected library. The script was tested with Python 2.7, TensorFlow 1.3, PyTorch 1.2 and Ubuntu 16.04.
**Note**: BigEarthNet patches with high density snow, cloud and cloud shadow are not included in the training, test and validation sets constructed by the provided scripts (see the list of patches with seasonal snow [here](http://bigearth.net/static/documents/patches_with_seasonal_snow.csv) and that of cloud and cloud shadow [here](http://bigearth.net/static/documents/patches_with_cloud_and_shadow.csv)).
Authors
-------
**Gencer Sümbül**
http://www.user.tu-berlin.de/gencersumbul/
**Jian Kang**
https://www.rsim.tu-berlin.de/menue/team/dring_jian_kang/
**Tristan Kreuziger**
https://www.rsim.tu-berlin.de/menue/team/tristan_kreuziger/
Maintained by
-------
**Gencer Sümbül** for TensorFlow models
**Jian Kang** for PyTorch models
# License
The BigEarthNet Archive is licensed under the **Community Data License Agreement – Permissive, Version 1.0** ([Text](https://cdla.io/permissive-1-0/)).
#
# prep_splits.py --help can be used to learn how to use this script.
#
# Author: Gencer Sumbul, http://www.user.tu-berlin.de/gencersumbul/, Jian Kang, https://www.rsim.tu-berlin.de/menue/team/dring_jian_kang/
# Email: gencer.suembuel@tu-berlin.de, jian.kang@tu-berlin.de
# Date: 16 Dec 2019
# Version: 1.0.1
# Usage: prep_splits.py [-h] [-r ROOT_FOLDER] [-o OUT_FOLDER]
import argparse
import os
import csv
import json
from tensorflow_utils import prep_tf_record_files
from pytorch_utils import prep_lmdb_files

# Spectral band names to read related GeoTIFF files
band_names = ['B01', 'B02', 'B03', 'B04', 'B05',
              'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']

GDAL_EXISTED = False
RASTERIO_EXISTED = False

with open('label_indices.json', 'rb') as f:
    label_indices = json.load(f)
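The exact contents of `label_indices.json` are not shown in this commit; based on how `label_indices['original_labels']` is indexed later in `create_split`, it is assumed to map each CLC label name to a position in the multi-hot vector. A minimal sketch with a hypothetical three-label file:

```python
import json

# Hypothetical miniature version of label_indices.json: each label name
# maps to its index in the multi-hot label vector.
label_indices_json = json.dumps({
    "original_labels": {
        "Continuous urban fabric": 0,
        "Discontinuous urban fabric": 1,
        "Sea and ocean": 2
    }
})

label_indices = json.loads(label_indices_json)

# Build a multi-hot vector the same way create_split does.
multi_hot = [0] * len(label_indices['original_labels'])
for label in ["Sea and ocean"]:
    multi_hot[label_indices['original_labels'][label]] = 1

print(multi_hot)  # -> [0, 0, 1]
```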
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=
        'This script creates TFRecord or LMDB files for the BigEarthNet train, validation and test splits')
    parser.add_argument('-r', '--root_folder', dest='root_folder',
                        help='root folder path containing multiple patch folders')
    parser.add_argument('-o', '--out_folder', dest='out_folder',
                        help='folder path where the resulting TFRecord or LMDB files will be created')
    parser.add_argument('-n', '--splits', dest='splits', nargs='+',
                        help='csv files each of which contains a list of patch names; patches with snow, clouds, and shadows are already excluded')
    parser.add_argument('-l', '--library', type=str, dest='library', choices=['tensorflow', 'pytorch'],
                        help='ML library for which the data files will be prepared')
    args = parser.parse_args()
    except ImportError:
        print('ERROR: please install either GDAL or rasterio package to read GeoTIFF files')
        exit()
    try:
        import numpy as np
    except ImportError:
                patch_names_list[-1].append(row[0].strip())
    except:
        print('ERROR: some csv files either do not exist or have been corrupted')
        exit()
    if args.library == 'tensorflow':
        try:
            import tensorflow as tf
        except ImportError:
            print('ERROR: please install tensorflow package to create TFRecord files')
            exit()
        prep_tf_record_files(
            args.root_folder,
            args.out_folder,
            split_names,
            patch_names_list,
            label_indices,
            GDAL_EXISTED,
            RASTERIO_EXISTED
        )
    elif args.library == 'pytorch':
        prep_lmdb_files(
            args.root_folder,
            args.out_folder,
            patch_names_list,
            GDAL_EXISTED,
            RASTERIO_EXISTED
        )
# pytorch_utils.py
import json
import csv
import os
import numpy as np
from collections import defaultdict
# original labels
LABELS = [
'Continuous urban fabric',
'Discontinuous urban fabric',
'Industrial or commercial units',
'Road and rail networks and associated land',
'Port areas',
'Airports',
'Mineral extraction sites',
'Dump sites',
'Construction sites',
'Green urban areas',
'Sport and leisure facilities',
'Non-irrigated arable land',
'Permanently irrigated land',
'Rice fields',
'Vineyards',
'Fruit trees and berry plantations',
'Olive groves',
'Pastures',
'Annual crops associated with permanent crops',
'Complex cultivation patterns',
'Land principally occupied by agriculture, with significant areas of natural vegetation',
'Agro-forestry areas',
'Broad-leaved forest',
'Coniferous forest',
'Mixed forest',
'Natural grassland',
'Moors and heathland',
'Sclerophyllous vegetation',
'Transitional woodland/shrub',
'Beaches, dunes, sands',
'Bare rock',
'Sparsely vegetated areas',
'Burnt areas',
'Inland marshes',
'Peatbogs',
'Salt marshes',
'Salines',
'Intertidal flats',
'Water courses',
'Water bodies',
'Coastal lagoons',
'Estuaries',
'Sea and ocean'
]
# the new labels
NEW_LABELS = [
'Urban fabric',
'Industrial or commercial units',
'Arable land',
'Permanent crops',
'Pastures',
'Complex cultivation patterns',
'Land principally occupied by agriculture, with significant areas of natural vegetation',
'Agro-forestry areas',
'Broad-leaved forest',
'Coniferous forest',
'Mixed forest',
'Natural grassland and sparsely vegetated areas',
'Moors, heathland and sclerophyllous vegetation',
'Transitional woodland/shrub',
'Beaches, dunes, sands',
'Inland wetlands',
'Coastal wetlands',
'Inland waters',
'Marine waters'
]
# removed labels from the original 43 labels
REMOVED_LABELS = [
'Road and rail networks and associated land',
'Port areas',
'Airports',
'Mineral extraction sites',
'Dump sites',
'Construction sites',
'Green urban areas',
'Sport and leisure facilities',
'Bare rock',
'Burnt areas',
'Intertidal flats'
]
# merged labels
GROUP_LABELS = {
'Continuous urban fabric':'Urban fabric',
'Discontinuous urban fabric':'Urban fabric',
'Non-irrigated arable land':'Arable land',
'Permanently irrigated land':'Arable land',
'Rice fields':'Arable land',
'Vineyards':'Permanent crops',
'Fruit trees and berry plantations':'Permanent crops',
'Olive groves':'Permanent crops',
'Annual crops associated with permanent crops':'Permanent crops',
'Natural grassland':'Natural grassland and sparsely vegetated areas',
'Sparsely vegetated areas':'Natural grassland and sparsely vegetated areas',
'Moors and heathland':'Moors, heathland and sclerophyllous vegetation',
'Sclerophyllous vegetation':'Moors, heathland and sclerophyllous vegetation',
'Inland marshes':'Inland wetlands',
'Peatbogs':'Inland wetlands',
'Salt marshes':'Coastal wetlands',
'Salines':'Coastal wetlands',
'Water bodies':'Inland waters',
'Water courses':'Inland waters',
'Coastal lagoons':'Marine waters',
'Estuaries':'Marine waters',
'Sea and ocean':'Marine waters'
}
def multiHot2cls(multiHotCode):
    """
    convert a multi-hot label vector back to the list of label names
    """
    pos = np.where(np.squeeze(multiHotCode))[0].tolist()
    return np.array(LABELS)[pos,]
def cls2multiHot_new(cls_vec):
    """
    create the multi-hot vector for the new (19-class) nomenclature
    """
    tmp = np.zeros((len(NEW_LABELS),))
    for cls_nm in cls_vec:
        if cls_nm in GROUP_LABELS:
            tmp[NEW_LABELS.index(GROUP_LABELS[cls_nm])] = 1
        elif cls_nm in NEW_LABELS:
            tmp[NEW_LABELS.index(cls_nm)] = 1
        # labels in REMOVED_LABELS fall through and contribute nothing
    return tmp
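The mapping above can be illustrated with a self-contained sketch; the reduced label lists and the function name `cls_to_multi_hot_new` are hypothetical stand-ins (the real `NEW_LABELS` has 19 entries and the real code uses numpy):

```python
# Reduced label lists for illustration only.
NEW_LABELS = ['Urban fabric', 'Arable land', 'Marine waters']
GROUP_LABELS = {
    'Continuous urban fabric': 'Urban fabric',
    'Rice fields': 'Arable land',
    'Sea and ocean': 'Marine waters',
}

def cls_to_multi_hot_new(cls_vec):
    # Map each original label to its merged class; removed labels are skipped.
    multi_hot = [0] * len(NEW_LABELS)
    for cls_nm in cls_vec:
        if cls_nm in GROUP_LABELS:
            multi_hot[NEW_LABELS.index(GROUP_LABELS[cls_nm])] = 1
        elif cls_nm in NEW_LABELS:
            multi_hot[NEW_LABELS.index(cls_nm)] = 1
    return multi_hot

# 'Airports' is one of the removed labels, so it contributes nothing.
print(cls_to_multi_hot_new(['Continuous urban fabric', 'Sea and ocean', 'Airports']))
# -> [1, 0, 1]
```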
def cls2multiHot_old(cls_vec):
    """
    create the multi-hot vector for the original (43-class) nomenclature
    """
    tmp = np.zeros((len(LABELS),))
    for cls_nm in cls_vec:
        tmp[LABELS.index(cls_nm)] = 1
    return tmp
def read_scale_raster(file_path, GDAL_EXISTED, RASTERIO_EXISTED):
    """
    read a single-band raster file with GDAL or rasterio, whichever is available
    :param file_path: path of the GeoTIFF file to read
    :return: band values as a numpy array
    """
    if GDAL_EXISTED:
        import gdal
        band_ds = gdal.Open(file_path, gdal.GA_ReadOnly)
        raster_band = band_ds.GetRasterBand(1)
        band_data = raster_band.ReadAsArray()
    elif RASTERIO_EXISTED:
        import rasterio
        band_ds = rasterio.open(file_path)
        band_data = np.array(band_ds.read(1))
    return band_data
def parse_json_labels(f_j_path):
    """
    parse the BigEarthNet metadata json file to get the image labels
    :param f_j_path: json file path
    :return: list of label names
    """
    with open(f_j_path, 'r') as f_j:
        j_f_c = json.load(f_j)
    return j_f_c['labels']
class dataGenBigEarthTiff:
    def __init__(self, bigEarthDir=None,
                 bands10=None, bands20=None, bands60=None,
                 patch_names_list=None,
                 RASTERIO_EXISTED=None, GDAL_EXISTED=None
                 ):
        self.bigEarthDir = bigEarthDir
        self.bands10 = bands10
        self.bands20 = bands20
        self.bands60 = bands60
        self.GDAL_EXISTED = GDAL_EXISTED
        self.RASTERIO_EXISTED = RASTERIO_EXISTED
        # concatenate the train, validation and test patch name lists
        self.total_patch = patch_names_list[0] + patch_names_list[1] + patch_names_list[2]

    def __len__(self):
        return len(self.total_patch)

    def __getitem__(self, index):
        return self.__data_generation(index)

    def __data_generation(self, idx):
        imgNm = self.total_patch[idx]

        bands10_array = []
        bands20_array = []
        bands60_array = []
        if self.bands10 is not None:
            for band in self.bands10:
                bands10_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm + '_B' + band + '.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))
        if self.bands20 is not None:
            for band in self.bands20:
                bands20_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm + '_B' + band + '.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))
        if self.bands60 is not None:
            for band in self.bands60:
                bands60_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm + '_B' + band + '.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))

        bands10_array = np.asarray(bands10_array).astype(np.float32)
        bands20_array = np.asarray(bands20_array).astype(np.float32)
        bands60_array = np.asarray(bands60_array).astype(np.float32)

        labels = parse_json_labels(os.path.join(self.bigEarthDir, imgNm, imgNm + '_labels_metadata.json'))
        # astype returns a new array, so the result must be assigned
        oldMultiHots = cls2multiHot_old(labels).astype(int)
        newMultiHots = cls2multiHot_new(labels).astype(int)

        sample = {'bands10': bands10_array, 'bands20': bands20_array, 'bands60': bands60_array,
                  'patch_name': imgNm, 'multi_hots_n': newMultiHots, 'multi_hots_o': oldMultiHots}
        return sample
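`dataGenBigEarthTiff` follows PyTorch's map-style dataset protocol, i.e. it only needs `__len__` and `__getitem__`, which is why it can be handed directly to a `DataLoader`. The protocol itself requires nothing from PyTorch and can be sketched with a toy dataset (`ToyPatchDataset` is a hypothetical stand-in):

```python
class ToyPatchDataset:
    """Minimal map-style dataset: anything with __len__ and __getitem__."""
    def __init__(self, patch_names_list):
        # Like dataGenBigEarthTiff, flatten the train/val/test name lists.
        self.total_patch = patch_names_list[0] + patch_names_list[1] + patch_names_list[2]

    def __len__(self):
        return len(self.total_patch)

    def __getitem__(self, idx):
        return {'patch_name': self.total_patch[idx]}

ds = ToyPatchDataset([['a', 'b'], ['c'], ['d']])
print(len(ds))              # -> 4
print(ds[2]['patch_name'])  # -> c
```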
def dumps_pyarrow(obj):
    """
    Serialize an object with pyarrow.
    Returns:
        Implementation-dependent bytes-like object
    """
    import pyarrow as pa
    return pa.serialize(obj).to_buffer()
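Note that `pyarrow.serialize` was deprecated in pyarrow 2.0 and removed in later releases. If reproducing this pipeline with a current pyarrow, one possible drop-in replacement is to pickle the tuple instead; this is a sketch of an alternative, not the original implementation, and the function names are hypothetical:

```python
import pickle

def dumps_pickle(obj):
    # Stand-in for dumps_pyarrow on pyarrow versions without pa.serialize.
    return pickle.dumps(obj)

def loads_pickle(buf):
    # Inverse of dumps_pickle, for reading records back out of LMDB.
    return pickle.loads(buf)

# Round-trip a record shaped like the tuples written below.
payload = ({'B02': [1, 2, 3]}, 'hypothetical_patch_name')
assert loads_pickle(dumps_pickle(payload)) == payload
```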
def prep_lmdb_files(root_folder, out_folder, patch_names_list, GDAL_EXISTED, RASTERIO_EXISTED):
    from torch.utils.data import DataLoader
    import lmdb

    dataGen = dataGenBigEarthTiff(
        bigEarthDir=root_folder,
        bands10=['02', '03', '04', '08'],
        bands20=['05', '06', '07', '8A', '11', '12'],
        bands60=['01', '09'],
        patch_names_list=patch_names_list,
        GDAL_EXISTED=GDAL_EXISTED,
        RASTERIO_EXISTED=RASTERIO_EXISTED
    )
    nSamples = len(dataGen)
    # Over-allocate the memory map: ten times the per-patch byte size times the number of patches.
    map_size_ = (dataGen[0]['bands10'].nbytes + dataGen[0]['bands20'].nbytes + dataGen[0]['bands60'].nbytes) * 10 * len(dataGen)
    data_loader = DataLoader(dataGen, num_workers=4, collate_fn=lambda x: x)

    db = lmdb.open(os.path.join(out_folder, 'BigEarth_v1_4pt_org.lmdb'), map_size=map_size_)
    txn = db.begin(write=True)
    patch_names = []
    for idx, data in enumerate(data_loader):
        bands10, bands20, bands60, patch_name, _, multiHots_o = data[0]['bands10'], data[0]['bands20'], data[0]['bands60'], data[0]['patch_name'], data[0]['multi_hots_n'], data[0]['multi_hots_o']
        # txn.put(u'{}'.format(patch_name).encode('ascii'), dumps_pyarrow((bands10, bands20, bands60, multiHots_n, multiHots_o)))
        txn.put(u'{}'.format(patch_name).encode('ascii'), dumps_pyarrow((bands10, bands20, bands60, multiHots_o)))
        patch_names.append(patch_name)
        if idx % 10000 == 0:
            print("[%d/%d]" % (idx, nSamples))
            # Commit periodically so the write transaction does not grow unbounded.
            txn.commit()
            txn = db.begin(write=True)
    txn.commit()

    keys = [u'{}'.format(patch_name).encode('ascii') for patch_name in patch_names]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_pyarrow(keys))
        txn.put(b'__len__', dumps_pyarrow(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
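The `map_size_` heuristic above reserves roughly ten times the raw pixel bytes of the whole archive, since LMDB needs the maximum map size up front. The arithmetic can be sketched with hypothetical patch dimensions (BigEarthNet stores 10m bands as 120x120, 20m as 60x60 and 60m as 20x20 pixels; the patch count here is invented):

```python
# Hypothetical per-patch byte counts for the three resolution groups,
# stored as float32 (4 bytes per value) as in dataGenBigEarthTiff.
bands10_nbytes = 4 * 120 * 120 * 4  # four 10m bands
bands20_nbytes = 6 * 60 * 60 * 4    # six 20m bands
bands60_nbytes = 2 * 20 * 20 * 4    # two 60m bands
n_patches = 1000                    # invented archive size

# Same formula as map_size_ in prep_lmdb_files.
map_size = (bands10_nbytes + bands20_nbytes + bands60_nbytes) * 10 * n_patches
print(map_size)  # -> 3200000000
```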
# tensorflow_utils.py
import tensorflow as tf
import numpy as np
import os
import json
# Spectral band names to read related GeoTIFF files
band_names = ['B01', 'B02', 'B03', 'B04', 'B05',
'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
def prep_example(bands, original_labels, original_labels_multi_hot, patch_name):
    # One int64 feature per spectral band, flattened to a 1-D list.
    feature = {
        band_name: tf.train.Feature(
            int64_list=tf.train.Int64List(value=np.ravel(bands[band_name])))
        for band_name in band_names
    }
    feature['original_labels'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[i.encode('utf-8') for i in original_labels]))
    feature['original_labels_multi_hot'] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=original_labels_multi_hot))
    feature['patch_name'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[patch_name.encode('utf-8')]))
    return tf.train.Example(features=tf.train.Features(feature=feature))
def create_split(root_folder, patch_names, TFRecord_writer, label_indices, GDAL_EXISTED, RASTERIO_EXISTED):
    if GDAL_EXISTED:
        import gdal
    elif RASTERIO_EXISTED:
        import rasterio
    progress_bar = tf.contrib.keras.utils.Progbar(target=len(patch_names))
    for patch_idx, patch_name in enumerate(patch_names):
        patch_folder_path = os.path.join(root_folder, patch_name)
        bands = {}
        for band_name in band_names:
            # First find the related GeoTIFF path and read values as an array
            band_path = os.path.join(
                patch_folder_path, patch_name + '_' + band_name + '.tif')
            if GDAL_EXISTED:
                band_ds = gdal.Open(band_path, gdal.GA_ReadOnly)
                raster_band = band_ds.GetRasterBand(1)
                band_data = raster_band.ReadAsArray()
                bands[band_name] = np.array(band_data)
            elif RASTERIO_EXISTED:
                band_ds = rasterio.open(band_path)
                band_data = np.array(band_ds.read(1))
                bands[band_name] = np.array(band_data)

        original_labels_multi_hot = np.zeros(
            len(label_indices['original_labels'].keys()), dtype=int)
        patch_json_path = os.path.join(
            patch_folder_path, patch_name + '_labels_metadata.json')

        with open(patch_json_path, 'rb') as f:
            patch_json = json.load(f)
        original_labels = patch_json['labels']
        for label in original_labels:
            original_labels_multi_hot[label_indices['original_labels'][label]] = 1

        example = prep_example(
            bands,
            original_labels,
            original_labels_multi_hot,
            patch_name
        )
        TFRecord_writer.write(example.SerializeToString())
        progress_bar.update(patch_idx)
def prep_tf_record_files(root_folder, out_folder, split_names, patch_names_list, label_indices, GDAL_EXISTED, RASTERIO_EXISTED):
    try:
        writer_list = []
        for split_name in split_names:
            writer_list.append(
                tf.python_io.TFRecordWriter(os.path.join(
                    out_folder, split_name + '.tfrecord'))
            )
    except:
        print('ERROR: TFRecord writer is not able to write files')
        exit()
    for split_idx in range(len(patch_names_list)):
        print('INFO: creating the split of', split_names[split_idx], 'is started')
        create_split(
            root_folder,
            patch_names_list[split_idx],
            writer_list[split_idx],
            label_indices,
            GDAL_EXISTED,
            RASTERIO_EXISTED,
        )
        writer_list[split_idx].close()